use crate::iterator_traits::StridedIterator;
use crate::{iterator_traits::IterGetSet, strided_map_mut::StridedMapMut};
use hpt_traits::ops::creation::TensorCreator;
use hpt_traits::tensor::{CommonBounds, TensorInfo};
pub mod strided_map_simd {
    //! Deferred element-wise mapping over a SIMD-capable strided iterator.

    use crate::iterator_traits::{IterGetSetSimd, StridedIteratorSimd};
    use crate::strided_map_mut::strided_map_mut_simd::StridedMapMutSimd;
    use crate::{CommonBounds, TensorInfo};
    use hpt_traits::ops::creation::TensorCreator;
    use hpt_types::dtype::TypeCommon;

    /// Pairs a SIMD-capable strided iterator with two mapping functions.
    ///
    /// Nothing is evaluated until `collect` is called. `f` maps a single scalar item and
    /// `f2` maps a whole SIMD vector of items.
    #[derive(Clone)]
    pub struct StridedMapSimd<'a, I: 'a + IterGetSetSimd<Item = T>, T: 'a, F, F2> {
        /// Source iterator whose items (scalar or vector) are mapped.
        pub(crate) iter: I,
        /// Scalar mapping: one `T` to one output element.
        pub(crate) f: F,
        /// Vector mapping: one `SimdItem` to one output SIMD vector.
        pub(crate) f2: F2,
        /// Carries the `'a` lifetime parameter; holds no data.
        pub(crate) phantom: std::marker::PhantomData<&'a ()>,
    }

    impl<'a, I, T, F, F2> StridedMapSimd<'a, I, T, F, F2>
    where
        I: 'a + IterGetSetSimd<Item = T>,
        T: 'a,
    {
        /// Runs the mapping and stores the results in a newly created tensor of type `U`.
        ///
        /// The output has the same shape as the source iterator. The output tensor is
        /// zipped with the source, and the resulting zip's `for_each` receives both
        /// closures: the scalar one (built from `f`) and the vector one (built from `f2`).
        /// NOTE(review): which closure runs for which elements is decided inside
        /// `StridedMapMutSimd`'s zip/`for_each`, not here — confirm there.
        ///
        /// # Panics
        /// Panics if `U::empty` returns an error for the source iterator's shape.
        pub fn collect<U>(self) -> U
        where
            F: Fn(T) -> U::Meta + Sync + Send + 'a,
            F2: Fn(
                    <I as IterGetSetSimd>::SimdItem,
                ) -> <<U as TensorCreator>::Meta as TypeCommon>::Vec
                + Sync
                + Send
                + 'a,
            U: Clone + TensorInfo<U::Meta> + TensorCreator<Output = U>,
            <I as IterGetSetSimd>::Item: Send,
            <U as TensorCreator>::Meta: CommonBounds,
            <<U as TensorCreator>::Meta as TypeCommon>::Vec: Send,
        {
            let Self { iter, f, f2, .. } = self;

            // Destination with the same shape as the source.
            let output = U::empty(iter.shape().clone()).unwrap();

            // NOTE(review): results are written through a clone of `output`, yet `output`
            // itself is returned — this relies on `U::clone` sharing the underlying
            // buffer. Confirm for every `U` this is instantiated with.
            StridedMapMutSimd::new(output.clone()).zip(iter).for_each(
                |(dst, src)| *dst = f(src),
                |(dst_vec, src_vec)| *dst_vec = f2(src_vec),
            );

            output
        }
    }
}
/// Pairs a strided iterator with a per-item mapping function.
///
/// Nothing is evaluated until `collect` is called.
#[derive(Clone)]
pub struct StridedMap<'a, I: 'a + IterGetSet<Item = T>, T: 'a, F> {
    /// Source iterator whose items are fed to `f`.
    pub(crate) iter: I,
    /// Mapping applied to each item produced by `iter`.
    pub(crate) f: F,
    /// Carries the `'a` lifetime parameter; holds no data.
    pub(crate) phantom: std::marker::PhantomData<&'a ()>,
}
impl<'a, I, T, F> StridedMap<'a, I, T, F>
where
    I: 'a + IterGetSet<Item = T>,
    T: 'a,
{
    /// Runs the mapping and stores the results in a newly created tensor of type `U`.
    ///
    /// The output is created with `U::empty` using the source iterator's shape. It is then
    /// zipped with the source, and each output element is assigned `f(item)`.
    ///
    /// # Panics
    /// Panics if `U::empty` returns an error for the source iterator's shape.
    pub fn collect<U>(self) -> U
    where
        F: Fn(T) -> U::Meta + Sync + Send + 'a,
        U: Clone + TensorInfo<U::Meta> + TensorCreator<Output = U>,
        <I as IterGetSet>::Item: Send,
        <U as TensorCreator>::Meta: CommonBounds,
    {
        let Self { iter, f, .. } = self;

        // Destination with the same shape as the source.
        let output = U::empty(iter.shape().clone()).unwrap();

        // NOTE(review): results are written through a clone of `output`, yet `output` itself
        // is returned — this relies on `U::clone` sharing the underlying buffer. Confirm for
        // every `U` this is instantiated with.
        StridedMapMut::new(output.clone())
            .zip(iter)
            .for_each(|(dst, src)| *dst = f(src));

        output
    }
}