use crate::{
iterator_traits::{IterGetSet, ParStridedHelper, ParStridedIteratorZip, ShapeManipulator},
par_strided::ParStrided,
shape_manipulate::{par_expand, par_reshape, par_transpose},
};
use hpt_common::shape::shape::Shape;
use hpt_traits::tensor::{CommonBounds, TensorInfo};
use rayon::iter::{
plumbing::{bridge_unindexed, Folder, UnindexedConsumer, UnindexedProducer},
ParallelIterator,
};
use std::sync::Arc;
/// Mutable parallel strided iteration with SIMD-vector access.
pub mod par_strided_map_mut_simd {
    use crate::{
        iterator_traits::{IterGetSetSimd, ParStridedIteratorSimd, ParStridedIteratorSimdZip},
        par_strided::par_strided_simd::ParStridedSimd,
    };
    use crate::{CommonBounds, TensorInfo};
    use hpt_common::{shape::shape::Shape, utils::pointer::Pointer, utils::simd_ref::MutVec};
    use hpt_types::dtype::TypeCommon;
    use hpt_types::traits::VecTrait;
    use rayon::iter::{
        plumbing::{bridge_unindexed, Folder, UnindexedConsumer, UnindexedProducer},
        ParallelIterator,
    };
    use std::sync::Arc;

    /// Parallel strided iterator yielding `&'a mut T` scalars and `MutVec` SIMD
    /// views over a tensor's elements.
    ///
    /// All layout and cursor state lives in `base`; the `'a` parameter only ties
    /// the yielded references to a lifetime.
    pub struct ParStridedMutSimd<'a, T: TypeCommon + Send + Copy + Sync> {
        /// Underlying SIMD-aware strided cursor.
        pub(crate) base: ParStridedSimd<T>,
        /// Zero-sized marker carrying the `'a` lifetime.
        pub(crate) phantom: std::marker::PhantomData<&'a ()>,
    }

    impl<'a, T: CommonBounds> ParStridedMutSimd<'a, T> {
        /// Builds the iterator from anything exposing tensor layout information.
        pub fn new<U: TensorInfo<T>>(tensor: U) -> Self {
            ParStridedMutSimd {
                base: ParStridedSimd::new(tensor),
                phantom: std::marker::PhantomData,
            }
        }
    }

    impl<'a, T: CommonBounds> ParStridedIteratorSimdZip for ParStridedMutSimd<'a, T> {}
    impl<'a, T: CommonBounds> ParStridedIteratorSimd for ParStridedMutSimd<'a, T> {}

    impl<'a, T> ParallelIterator for ParStridedMutSimd<'a, T>
    where
        T: CommonBounds,
        T::Vec: Send,
    {
        type Item = &'a mut T;

        /// Hands the iterator to rayon via [`bridge_unindexed`], which splits with
        /// `UnindexedProducer::split` and consumes with `UnindexedProducer::fold_with`.
        fn drive_unindexed<C>(self, consumer: C) -> C::Result
        where
            C: UnindexedConsumer<Self::Item>,
        {
            bridge_unindexed(self, consumer)
        }
    }

    impl<'a, T> UnindexedProducer for ParStridedMutSimd<'a, T>
    where
        T: CommonBounds,
        T::Vec: Send,
    {
        type Item = &'a mut T;

        /// Splits the underlying `ParStridedSimd` and re-wraps each half.
        fn split(self) -> (Self, Option<Self>) {
            let (a, b) = self.base.split();
            (
                ParStridedMutSimd {
                    base: a,
                    phantom: std::marker::PhantomData,
                },
                b.map(|x| ParStridedMutSimd {
                    base: x,
                    phantom: std::marker::PhantomData,
                }),
            )
        }

        /// Returns `folder` without feeding it any items.
        ///
        /// NOTE(review): because of this, driving the iterator through plain rayon
        /// adaptors (`for_each`, `collect`, ...) visits no elements; presumably
        /// element traversal happens elsewhere through `IterGetSetSimd` — confirm.
        fn fold_with<F>(self, folder: F) -> F
        where
            F: Folder<Self::Item>,
        {
            folder
        }
    }

    /// Cursor/layout access for the SIMD iteration machinery. Everything except
    /// the two `inner_loop_next*` accessors is forwarded to `base`.
    impl<'a, T: 'a> IterGetSetSimd for ParStridedMutSimd<'a, T>
    where
        T: CommonBounds,
        T::Vec: Send,
    {
        type Item = &'a mut T;
        type SimdItem
            = MutVec<'a, T::Vec>
        where
            Self: 'a;

        fn set_end_index(&mut self, end_index: usize) {
            self.base.set_end_index(end_index);
        }

        fn set_intervals(&mut self, intervals: Arc<Vec<(usize, usize)>>) {
            self.base.set_intervals(intervals);
        }

        fn set_strides(&mut self, strides: hpt_common::strides::strides::Strides) {
            self.base.set_strides(strides);
        }

        fn set_shape(&mut self, shape: Shape) {
            self.base.set_shape(shape);
        }

        fn set_prg(&mut self, prg: Vec<i64>) {
            self.base.set_prg(prg);
        }

        fn intervals(&self) -> &Arc<Vec<(usize, usize)>> {
            self.base.intervals()
        }

        fn strides(&self) -> &hpt_common::strides::strides::Strides {
            self.base.strides()
        }

        fn shape(&self) -> &Shape {
            self.base.shape()
        }

        fn layout(&self) -> &hpt_common::layout::layout::Layout {
            self.base.layout()
        }

        fn broadcast_set_strides(&mut self, shape: &Shape) {
            self.base.broadcast_set_strides(shape);
        }

        fn outer_loop_size(&self) -> usize {
            self.base.outer_loop_size()
        }

        fn inner_loop_size(&self) -> usize {
            self.base.inner_loop_size()
        }

        fn next(&mut self) {
            self.base.next();
        }

        fn next_simd(&mut self) {
            self.base.next_simd();
        }

        /// Returns a mutable reference to the element `index * last_stride`
        /// elements away from `base.ptr`.
        ///
        /// The stride is signed, so the offset is computed as `isize`; the previous
        /// `last_stride as usize` cast turned a negative stride into a huge offset
        /// (overflow panic in debug, out-of-range `add` in release).
        fn inner_loop_next(&mut self, index: usize) -> Self::Item {
            let offset = (index as isize) * (self.base.last_stride as isize);
            // SAFETY: assumes the caller passes `index < inner_loop_size()` so that
            // `base.ptr + offset` stays inside the tensor's allocation — not checked
            // here. TODO(review): confirm callers uphold this.
            unsafe {
                self.base
                    .ptr
                    .get_ptr()
                    .offset(offset)
                    .as_mut()
                    .expect("strided iterator pointer must be non-null")
            }
        }

        /// Returns a mutable SIMD view of the `index`-th vector starting at `base.ptr`.
        #[inline(always)]
        fn inner_loop_next_simd(&mut self, index: usize) -> Self::SimdItem {
            // The offset ignores `last_stride` and steps by whole vectors, so this
            // presumably requires a contiguous last axis (see `all_last_stride_one`).
            unsafe {
                let ptr = self.base.ptr.get_ptr().add(index * T::Vec::SIZE) as *mut T::Vec;
                #[cfg(feature = "bound_check")]
                return MutVec::new(Pointer::new(ptr, T::Vec::SIZE as i64));
                #[cfg(not(feature = "bound_check"))]
                return MutVec::new(Pointer::new(ptr));
            }
        }

        fn all_last_stride_one(&self) -> bool {
            self.base.all_last_stride_one()
        }

        fn lanes(&self) -> Option<usize> {
            self.base.lanes()
        }
    }
}
/// Parallel strided iterator that yields `&'a mut T` references to a tensor's elements.
///
/// A thin wrapper over [`ParStrided`]: the strided cursor state lives in `base`,
/// while `'a` only ties the yielded mutable references to a lifetime.
pub struct ParStridedMut<'a, T> {
    /// Wrapped strided cursor that owns the layout and position state.
    pub(crate) base: ParStrided<T>,
    /// Zero-sized marker carrying the `'a` lifetime of yielded references.
    pub(crate) phantom: core::marker::PhantomData<&'a ()>,
}
/// Low-level layout hooks, each forwarded verbatim to the wrapped [`ParStrided`].
impl<'a, T> ParStridedHelper for ParStridedMut<'a, T>
where
    T: CommonBounds,
{
    fn _set_last_strides(&mut self, stride: i64) {
        let inner = &mut self.base;
        inner._set_last_strides(stride);
    }

    fn _set_strides(&mut self, strides: hpt_common::strides::strides::Strides) {
        let inner = &mut self.base;
        inner._set_strides(strides);
    }

    fn _set_shape(&mut self, shape: Shape) {
        let inner = &mut self.base;
        inner._set_shape(shape);
    }

    fn _layout(&self) -> &hpt_common::layout::layout::Layout {
        let inner = &self.base;
        inner._layout()
    }

    fn _set_intervals(&mut self, intervals: Arc<Vec<(usize, usize)>>) {
        let inner = &mut self.base;
        inner._set_intervals(intervals);
    }

    fn _set_end_index(&mut self, end_index: usize) {
        let inner = &mut self.base;
        inner._set_end_index(end_index);
    }
}
/// Shape adaptors; each one delegates to the `par_*` helper with the matching name
/// from `shape_manipulate`.
impl<'a, T> ShapeManipulator for ParStridedMut<'a, T>
where
    T: CommonBounds,
{
    /// Reinterprets the iteration space with a new shape (see `par_reshape`).
    fn reshape<S>(self, shape: S) -> Self
    where
        S: Into<Shape>,
    {
        par_reshape(self, shape)
    }

    /// Permutes the iteration axes (see `par_transpose`).
    fn transpose<AXIS>(self, axes: AXIS) -> Self
    where
        AXIS: Into<hpt_common::axis::axis::Axis>,
    {
        par_transpose(self, axes)
    }

    /// Expands the iteration space to `shape` (see `par_expand`).
    fn expand<S>(self, shape: S) -> Self
    where
        S: Into<Shape>,
    {
        par_expand(self, shape)
    }
}
/// Opts this iterator into `ParStridedIteratorZip`, using only its default methods.
impl<'a, T> ParStridedIteratorZip for ParStridedMut<'a, T> where T: CommonBounds {}
impl<'a, T: CommonBounds> ParStridedMut<'a, T> {
pub fn new<U: TensorInfo<T>>(tensor: U) -> Self {
ParStridedMut {
base: ParStrided::new(tensor),
phantom: std::marker::PhantomData,
}
}
}
/// rayon entry point; items are `&'a mut T`.
impl<'a, T: CommonBounds> ParallelIterator for ParStridedMut<'a, T> {
    type Item = &'a mut T;

    /// Bridges to rayon via [`bridge_unindexed`], which splits and folds using this
    /// type's `UnindexedProducer` implementation.
    fn drive_unindexed<C: UnindexedConsumer<Self::Item>>(self, consumer: C) -> C::Result {
        bridge_unindexed(self, consumer)
    }
}
impl<'a, T> UnindexedProducer for ParStridedMut<'a, T>
where
T: CommonBounds,
{
type Item = &'a mut T;
fn split(self) -> (Self, Option<Self>) {
let (a, b) = self.base.split();
(
ParStridedMut {
base: a,
phantom: std::marker::PhantomData,
},
b.map(|x| ParStridedMut {
base: x,
phantom: std::marker::PhantomData,
}),
)
}
fn fold_with<F>(self, folder: F) -> F
where
F: Folder<Self::Item>,
{
folder
}
}
/// Cursor and layout access for the strided-iteration machinery.
///
/// Every method is forwarded to the wrapped [`ParStrided`] except `inner_loop_next`,
/// which materialises a mutable reference to the addressed element.
impl<'a, T: 'a> IterGetSet for ParStridedMut<'a, T>
where
    T: CommonBounds,
{
    type Item = &'a mut T;

    fn set_end_index(&mut self, end_index: usize) {
        self.base.set_end_index(end_index);
    }

    fn set_intervals(&mut self, intervals: Arc<Vec<(usize, usize)>>) {
        self.base.set_intervals(intervals);
    }

    fn set_strides(&mut self, strides: hpt_common::strides::strides::Strides) {
        self.base.set_strides(strides);
    }

    fn set_shape(&mut self, shape: Shape) {
        self.base.set_shape(shape);
    }

    fn set_prg(&mut self, prg: Vec<i64>) {
        self.base.set_prg(prg);
    }

    fn intervals(&self) -> &Arc<Vec<(usize, usize)>> {
        self.base.intervals()
    }

    fn strides(&self) -> &hpt_common::strides::strides::Strides {
        self.base.strides()
    }

    fn shape(&self) -> &Shape {
        self.base.shape()
    }

    fn layout(&self) -> &hpt_common::layout::layout::Layout {
        self.base.layout()
    }

    fn broadcast_set_strides(&mut self, shape: &Shape) {
        self.base.broadcast_set_strides(shape);
    }

    fn outer_loop_size(&self) -> usize {
        self.base.outer_loop_size()
    }

    fn inner_loop_size(&self) -> usize {
        self.base.inner_loop_size()
    }

    fn next(&mut self) {
        self.base.next();
    }

    /// Returns a mutable reference to the element `index * last_stride` elements
    /// away from `base.ptr`.
    ///
    /// The stride is signed (it is set from an `i64`), so the offset is computed as
    /// `isize` and applied with `offset`. Casting a negative stride to `usize` (the
    /// previous code) produced a huge offset, which is an overflow panic in debug
    /// builds and an out-of-range `add` (UB) in release builds.
    fn inner_loop_next(&mut self, index: usize) -> Self::Item {
        let offset = (index as isize) * (self.base.last_stride as isize);
        // SAFETY: assumes the caller passes `index < inner_loop_size()` so that
        // `base.ptr + offset` stays inside the tensor's allocation — not checked
        // here. TODO(review): confirm callers uphold this.
        unsafe {
            self.base
                .ptr
                .get_ptr()
                .offset(offset)
                .as_mut()
                .expect("strided iterator pointer must be non-null")
        }
    }
}