use hpt_common::{shape::shape::Shape, strides::strides::Strides};
use std::sync::Arc;
use crate::iterator_traits::{IterGetSet, StridedIterator, StridedIteratorMap, StridedIteratorZip};
pub mod strided_zip_simd {
use hpt_common::{shape::shape::Shape, strides::strides::Strides};
use crate::iterator_traits::{IterGetSetSimd, StridedIteratorSimd, StridedSimdIteratorZip};
use std::sync::Arc;
/// Pairs two SIMD-capable strided iterators so they can be driven together.
///
/// The `IterGetSetSimd` impl for this type forwards every mutating call to
/// both `a` and `b`, and answers `strides`/`shape`/`layout`/loop-size queries
/// from `a` alone.
#[derive(Clone)]
pub struct StridedZipSimd<'a, A: 'a, B: 'a> {
/// First zipped iterator; also the side whose strides, shape, layout and
/// loop sizes are reported for the pair.
pub(crate) a: A,
/// Second zipped iterator; stepped and reconfigured together with `a`.
pub(crate) b: B,
/// Zero-sized marker that ties the otherwise unused lifetime `'a` to the type.
pub(crate) phantom: std::marker::PhantomData<&'a ()>,
}
/// Lock-step pairing of two SIMD-capable strided iterators.
///
/// Mutating calls (`set_strides`, `set_shape`, `set_prg`,
/// `broadcast_set_strides`, `next`, `next_simd`) are forwarded to both `a` and
/// `b`. Geometry queries (`strides`, `shape`, `layout`, `outer_loop_size`,
/// `inner_loop_size`) are answered by `a` alone.
///
/// NOTE(review): answering geometry from `a` only presumably relies on both
/// sides already sharing one iteration shape — confirm against the callers.
impl<'a, A, B> IterGetSetSimd for StridedZipSimd<'a, A, B>
where
    A: IterGetSetSimd,
    B: IterGetSetSimd,
{
    /// Pair of scalar items, `a`'s first.
    type Item = (<A as IterGetSetSimd>::Item, <B as IterGetSetSimd>::Item);
    /// Pair of SIMD items, `a`'s first.
    type SimdItem = (
        <A as IterGetSetSimd>::SimdItem,
        <B as IterGetSetSimd>::SimdItem,
    );

    /// Not supported by the single-threaded zip.
    ///
    /// # Panics
    ///
    /// Always panics.
    fn set_end_index(&mut self, _: usize) {
        panic!("single thread strided zip does not support set_end_index");
    }

    /// Not supported by the single-threaded zip.
    ///
    /// # Panics
    ///
    /// Always panics.
    fn set_intervals(&mut self, _: Arc<Vec<(usize, usize)>>) {
        panic!("single thread strided zip does not support set_intervals");
    }

    /// Hands the same strides to both children (cloned for `a`, moved into `b`).
    fn set_strides(&mut self, last_stride: Strides) {
        self.a.set_strides(last_stride.clone());
        self.b.set_strides(last_stride);
    }

    /// Hands the same shape to both children (cloned for `a`, moved into `b`).
    fn set_shape(&mut self, shape: Shape) {
        self.a.set_shape(shape.clone());
        self.b.set_shape(shape);
    }

    /// Hands the same `prg` vector to both children (cloned for `a`, moved into `b`).
    fn set_prg(&mut self, prg: Vec<i64>) {
        self.a.set_prg(prg.clone());
        self.b.set_prg(prg);
    }

    /// Not supported by the single-threaded zip.
    ///
    /// # Panics
    ///
    /// Always panics.
    fn intervals(&self) -> &Arc<Vec<(usize, usize)>> {
        panic!("single thread strided zip does not support intervals");
    }

    /// Strides as reported by `a`; `b` is not consulted.
    fn strides(&self) -> &Strides {
        self.a.strides()
    }

    /// Shape as reported by `a`; `b` is not consulted.
    fn shape(&self) -> &Shape {
        self.a.shape()
    }

    /// Layout as reported by `a`; `b` is not consulted.
    fn layout(&self) -> &hpt_common::layout::layout::Layout {
        self.a.layout()
    }

    /// Forwards `broadcast_set_strides(shape)` to both children.
    fn broadcast_set_strides(&mut self, shape: &Shape) {
        self.a.broadcast_set_strides(shape);
        self.b.broadcast_set_strides(shape);
    }

    /// Outer loop size as reported by `a`; `b` is not consulted.
    fn outer_loop_size(&self) -> usize {
        self.a.outer_loop_size()
    }

    /// Inner loop size as reported by `a`; `b` is not consulted.
    fn inner_loop_size(&self) -> usize {
        self.a.inner_loop_size()
    }

    /// Calls `next` on `a`, then on `b`.
    fn next(&mut self) {
        self.a.next();
        self.b.next();
    }

    /// Calls `next_simd` on `a`, then on `b`, mirroring `next`.
    ///
    /// This used to be a `todo!()` stub that panicked unconditionally.
    fn next_simd(&mut self) {
        self.a.next_simd();
        self.b.next_simd();
    }

    /// Calls `inner_loop_next(index)` on `a`, then on `b`, and returns both
    /// results as a tuple.
    #[inline(always)]
    fn inner_loop_next(&mut self, index: usize) -> Self::Item {
        (self.a.inner_loop_next(index), self.b.inner_loop_next(index))
    }

    /// Calls `inner_loop_next_simd(index)` on `a`, then on `b`, and returns
    /// both results as a tuple.
    fn inner_loop_next_simd(&mut self, index: usize) -> Self::SimdItem {
        (
            self.a.inner_loop_next_simd(index),
            self.b.inner_loop_next_simd(index),
        )
    }

    /// `true` only when both children report `all_last_stride_one()`.
    fn all_last_stride_one(&self) -> bool {
        self.a.all_last_stride_one() && self.b.all_last_stride_one()
    }

    /// Lane count shared by both children.
    ///
    /// Returns `None` when either child reports no lane count or when the two
    /// counts differ.
    fn lanes(&self) -> Option<usize> {
        match (self.a.lanes(), self.b.lanes()) {
            (Some(lhs), Some(rhs)) if lhs == rhs => Some(lhs),
            _ => None,
        }
    }
}
impl<'a, A, B> StridedZipSimd<'a, A, B>
where
    A: 'a + IterGetSetSimd,
    B: 'a + IterGetSetSimd,
    <A as IterGetSetSimd>::Item: Send,
    <B as IterGetSetSimd>::Item: Send,
{
    /// Builds a zip over `a` and `b`.
    ///
    /// Both arguments are stored as given; no method is called on either one.
    pub fn new(a: A, b: B) -> Self {
        let phantom = std::marker::PhantomData;
        Self { a, b, phantom }
    }
}
/// Opts `StridedZipSimd` into `StridedIteratorSimd` whenever both children
/// implement `IterGetSetSimd`.
///
/// The body is empty: nothing is overridden here, so any behaviour comes from
/// the trait's own definition in `crate::iterator_traits`.
impl<'a, A, B> StridedIteratorSimd for StridedZipSimd<'a, A, B>
where
A: IterGetSetSimd,
B: IterGetSetSimd,
{
}
/// Opts `StridedZipSimd` into `StridedSimdIteratorZip` whenever both children
/// implement `IterGetSetSimd`.
///
/// The body is empty: nothing is overridden here, so any behaviour comes from
/// the trait's own definition in `crate::iterator_traits`.
impl<'a, A, B> StridedSimdIteratorZip for StridedZipSimd<'a, A, B>
where
A: IterGetSetSimd,
B: IterGetSetSimd,
{
}
}
/// Pairs two strided iterators so they can be driven together.
///
/// The `IterGetSet` impl for this type forwards every mutating call to both
/// `a` and `b`, and answers `strides`/`shape`/`layout`/loop-size queries from
/// `a` alone.
#[derive(Clone)]
pub struct StridedZip<'a, A: 'a, B: 'a> {
/// First zipped iterator; also the side whose strides, shape, layout and
/// loop sizes are reported for the pair.
pub(crate) a: A,
/// Second zipped iterator; stepped and reconfigured together with `a`.
pub(crate) b: B,
/// Zero-sized marker that ties the otherwise unused lifetime `'a` to the type.
pub(crate) phantom: std::marker::PhantomData<&'a ()>,
}
/// Lock-step pairing of two strided iterators.
///
/// Mutating calls (`set_strides`, `set_shape`, `set_prg`,
/// `broadcast_set_strides`, `next`) are forwarded to both `a` and `b`.
/// Geometry queries (`strides`, `shape`, `layout`, `outer_loop_size`,
/// `inner_loop_size`) are answered by `a` alone.
///
/// NOTE(review): answering geometry from `a` only presumably relies on both
/// sides already sharing one iteration shape — confirm against the callers.
impl<'a, A, B> IterGetSet for StridedZip<'a, A, B>
where
    A: IterGetSet,
    B: IterGetSet,
{
    /// Pair of items, `a`'s first.
    type Item = (<A as IterGetSet>::Item, <B as IterGetSet>::Item);

    /// Not supported by the single-threaded zip.
    ///
    /// # Panics
    ///
    /// Always panics.
    fn set_end_index(&mut self, _: usize) {
        panic!("single thread strided zip does not support set_end_index");
    }

    /// Not supported by the single-threaded zip.
    ///
    /// # Panics
    ///
    /// Always panics.
    fn set_intervals(&mut self, _: Arc<Vec<(usize, usize)>>) {
        panic!("single thread strided zip does not support set_intervals");
    }

    /// Hands the same strides to both children (cloned for `a`, moved into `b`).
    fn set_strides(&mut self, last_stride: Strides) {
        self.a.set_strides(last_stride.clone());
        self.b.set_strides(last_stride);
    }

    /// Hands the same shape to both children (cloned for `a`, moved into `b`).
    fn set_shape(&mut self, shape: Shape) {
        self.a.set_shape(shape.clone());
        self.b.set_shape(shape);
    }

    /// Hands the same `prg` vector to both children (cloned for `a`, moved into `b`).
    fn set_prg(&mut self, prg: Vec<i64>) {
        self.a.set_prg(prg.clone());
        self.b.set_prg(prg);
    }

    /// Not supported by the single-threaded zip.
    ///
    /// # Panics
    ///
    /// Always panics.
    fn intervals(&self) -> &Arc<Vec<(usize, usize)>> {
        panic!("single thread strided zip does not support intervals");
    }

    /// Strides as reported by `a`; `b` is not consulted.
    fn strides(&self) -> &Strides {
        self.a.strides()
    }

    /// Shape as reported by `a`; `b` is not consulted.
    fn shape(&self) -> &Shape {
        self.a.shape()
    }

    /// Layout as reported by `a`; `b` is not consulted.
    fn layout(&self) -> &hpt_common::layout::layout::Layout {
        self.a.layout()
    }

    /// Forwards `broadcast_set_strides(shape)` to both children.
    fn broadcast_set_strides(&mut self, shape: &Shape) {
        self.a.broadcast_set_strides(shape);
        self.b.broadcast_set_strides(shape);
    }

    /// Outer loop size as reported by `a`; `b` is not consulted.
    fn outer_loop_size(&self) -> usize {
        self.a.outer_loop_size()
    }

    /// Inner loop size as reported by `a`; `b` is not consulted.
    fn inner_loop_size(&self) -> usize {
        self.a.inner_loop_size()
    }

    /// Calls `next` on `a`, then on `b`.
    fn next(&mut self) {
        self.a.next();
        self.b.next();
    }

    /// Calls `inner_loop_next(index)` on `a`, then on `b`, and returns both
    /// results as a tuple.
    fn inner_loop_next(&mut self, index: usize) -> Self::Item {
        (self.a.inner_loop_next(index), self.b.inner_loop_next(index))
    }
}
impl<'a, A, B> StridedZip<'a, A, B>
where
    A: 'a + IterGetSet,
    B: 'a + IterGetSet,
    <A as IterGetSet>::Item: Send,
    <B as IterGetSet>::Item: Send,
{
    /// Builds a zip over `a` and `b`.
    ///
    /// Both arguments are stored as given; no method is called on either one.
    pub fn new(a: A, b: B) -> Self {
        let phantom = std::marker::PhantomData;
        Self { a, b, phantom }
    }
}
/// Opts `StridedZip` into `StridedIteratorZip` for every `A` and `B`; this
/// impl places no bounds on the two type parameters.
///
/// The body is empty: nothing is overridden here, so any behaviour comes from
/// the trait's own definition in `crate::iterator_traits`.
impl<'a, A, B> StridedIteratorZip for StridedZip<'a, A, B> {}
/// Opts `StridedZip` into `StridedIteratorMap` for every `A` and `B`; this
/// impl places no bounds on the two type parameters.
///
/// The body is empty: nothing is overridden here, so any behaviour comes from
/// the trait's own definition in `crate::iterator_traits`.
impl<'a, A, B> StridedIteratorMap for StridedZip<'a, A, B> {}
/// Opts `StridedZip` into `StridedIterator` whenever both children implement
/// `IterGetSet`.
///
/// The body is empty: nothing is overridden here, so any behaviour comes from
/// the trait's own definition in `crate::iterator_traits`.
impl<'a, A, B> StridedIterator for StridedZip<'a, A, B>
where
A: IterGetSet,
B: IterGetSet,
{
}