1use crate::{
2 iterator_traits::{
3 IterGetSet, ShapeManipulator, StridedHelper, StridedIterator, StridedIteratorMap,
4 StridedIteratorZip,
5 },
6 shape_manipulate::{expand, reshape, transpose},
7};
8use hpt_common::{
9 axis::axis::Axis, layout::layout::Layout, shape::shape::Shape,
10 shape::shape_utils::try_pad_shape, strides::strides::Strides,
11 strides::strides_utils::preprocess_strides, utils::pointer::Pointer,
12};
13use hpt_traits::tensor::{CommonBounds, TensorInfo};
14use std::sync::Arc;
15
16pub mod strided_simd {
18 use crate::{CommonBounds, TensorInfo};
19 use hpt_common::{
20 axis::axis::Axis, layout::layout::Layout, shape::shape::Shape,
21 shape::shape_utils::try_pad_shape, strides::strides::Strides,
22 strides::strides_utils::preprocess_strides, utils::pointer::Pointer,
23 };
24 use hpt_types::dtype::TypeCommon;
25 use hpt_types::vectors::traits::VecTrait;
26 use std::sync::Arc;
27
28 use crate::iterator_traits::{
29 IterGetSetSimd, ShapeManipulator, StridedIteratorMap, StridedIteratorSimd,
30 StridedSimdIteratorZip,
31 };
32
33 use super::{expand, reshape, transpose, StridedHelper};
34
35 #[derive(Clone)]
37 pub struct StridedSimd<T: TypeCommon> {
38 pub(crate) ptr: Pointer<T>,
40 pub(crate) layout: Layout,
42 pub(crate) prg: Vec<i64>,
44 pub(crate) last_stride: i64,
46 }
47
48 impl<T: CommonBounds> StridedSimd<T> {
49 pub fn shape(&self) -> &Shape {
55 self.layout.shape()
56 }
57 pub fn strides(&self) -> &Strides {
63 self.layout.strides()
64 }
65 pub fn new<U: TensorInfo<T>>(tensor: U) -> Self {
75 StridedSimd {
76 ptr: tensor.ptr(),
77 layout: tensor.layout().clone(),
78 prg: vec![],
79 last_stride: *tensor.strides().last().unwrap_or(&0),
80 }
81 }
82 }
83
84 impl<T: CommonBounds> IterGetSetSimd for StridedSimd<T> {
85 type Item = T;
86 type SimdItem = T::Vec;
87
88 fn set_end_index(&mut self, _: usize) {
89 panic!("single thread iterator does not support set_end_index");
90 }
91
92 fn set_intervals(&mut self, _: Arc<Vec<(usize, usize)>>) {
93 panic!("single thread iterator does not support set_intervals");
94 }
95
96 fn set_strides(&mut self, strides: Strides) {
97 self.layout.set_strides(strides);
98 }
99
100 fn set_shape(&mut self, shape: Shape) {
101 self.layout.set_shape(shape);
102 }
103
104 fn set_prg(&mut self, prg: Vec<i64>) {
105 self.prg = prg;
106 }
107
108 fn intervals(&self) -> &Arc<Vec<(usize, usize)>> {
109 panic!("single thread iterator does not support intervals");
110 }
111
112 fn strides(&self) -> &Strides {
113 self.layout.strides()
114 }
115
116 fn shape(&self) -> &Shape {
117 self.layout.shape()
118 }
119
120 fn layout(&self) -> &Layout {
121 &self.layout
122 }
123
124 fn broadcast_set_strides(&mut self, shape: &Shape) {
125 let self_shape = try_pad_shape(self.shape(), shape.len());
126 self.set_strides(preprocess_strides(&self_shape, self.strides()).into());
127 self.last_stride = self.strides()[self.strides().len() - 1];
128 }
129
130 fn outer_loop_size(&self) -> usize {
131 (self.shape().iter().product::<i64>() as usize) / self.inner_loop_size()
132 }
133 fn inner_loop_size(&self) -> usize {
134 self.shape().last().unwrap().clone() as usize
135 }
136
137 fn next(&mut self) {
138 for j in (0..(self.shape().len() as i64) - 1).rev() {
139 let j = j as usize;
140 if self.prg[j] < self.shape()[j] - 1 {
141 self.prg[j] += 1;
142 self.ptr.offset(self.strides()[j]);
143 break;
144 } else {
145 self.prg[j] = 0;
146 self.ptr.offset(-self.strides()[j] * (self.shape()[j] - 1));
147 }
148 }
149 }
150 fn next_simd(&mut self) {
151 todo!()
152 }
153 #[inline(always)]
154 fn inner_loop_next(&mut self, index: usize) -> Self::Item {
155 unsafe {
156 *self
157 .ptr
158 .ptr
159 .offset((index as isize) * (self.last_stride as isize))
160 }
161 }
162 fn inner_loop_next_simd(&mut self, index: usize) -> Self::SimdItem {
163 unsafe { Self::SimdItem::from_ptr(self.ptr.get_ptr().add(index * T::Vec::SIZE)) }
164 }
165 fn all_last_stride_one(&self) -> bool {
166 self.last_stride == 1
167 }
168
169 fn lanes(&self) -> Option<usize> {
170 Some(T::Vec::SIZE)
171 }
172 }
173
174 impl<T: CommonBounds> ShapeManipulator for StridedSimd<T> {
175 #[track_caller]
176 fn reshape<S: Into<Shape>>(self, shape: S) -> Self {
177 reshape(self, shape)
178 }
179
180 fn transpose<AXIS: Into<Axis>>(self, axes: AXIS) -> Self {
181 transpose(self, axes)
182 }
183
184 fn expand<S: Into<Shape>>(self, shape: S) -> Self {
185 expand(self, shape)
186 }
187 }
188
189 impl<T: CommonBounds> StridedHelper for StridedSimd<T> {
190 fn _set_last_strides(&mut self, stride: i64) {
191 self.last_stride = stride;
192 }
193 fn _set_strides(&mut self, strides: Strides) {
194 self.layout.set_strides(strides);
195 }
196 fn _set_shape(&mut self, shape: Shape) {
197 self.layout.set_shape(shape);
198 }
199 fn _layout(&self) -> &Layout {
200 &self.layout
201 }
202 }
203 impl<T: CommonBounds> StridedIteratorMap for StridedSimd<T> {}
204 impl<T: CommonBounds> StridedSimdIteratorZip for StridedSimd<T> {}
205 impl<T> StridedIteratorSimd for StridedSimd<T> where T: CommonBounds {}
206}
207
208#[derive(Clone)]
210pub struct Strided<T> {
211 pub(crate) ptr: Pointer<T>,
213 pub(crate) layout: Layout,
215 pub(crate) prg: Vec<i64>,
217 pub(crate) last_stride: i64,
219}
220
221impl<T: CommonBounds> Strided<T> {
222 pub fn shape(&self) -> &Shape {
228 self.layout.shape()
229 }
230 pub fn strides(&self) -> &Strides {
236 self.layout.strides()
237 }
238 pub fn new<U: TensorInfo<T>>(tensor: U) -> Self {
248 Strided {
249 ptr: tensor.ptr(),
250 layout: tensor.layout().clone(),
251 prg: vec![],
252 last_stride: *tensor.strides().last().unwrap_or(&0),
253 }
254 }
255}
256
257impl<T> StridedIteratorMap for Strided<T> {}
258impl<T> StridedIteratorZip for Strided<T> {}
259
260impl<T: CommonBounds> IterGetSet for Strided<T> {
261 type Item = T;
262
263 fn set_end_index(&mut self, _: usize) {
264 panic!("single thread iterator does not support set_end_index");
265 }
266
267 fn set_intervals(&mut self, _: Arc<Vec<(usize, usize)>>) {
268 panic!("single thread iterator does not support set_intervals");
269 }
270
271 fn set_strides(&mut self, strides: Strides) {
272 self.layout.set_strides(strides);
273 }
274
275 fn set_shape(&mut self, shape: Shape) {
276 self.layout.set_shape(shape);
277 }
278
279 fn set_prg(&mut self, prg: Vec<i64>) {
280 self.prg = prg;
281 }
282
283 fn intervals(&self) -> &Arc<Vec<(usize, usize)>> {
284 panic!("single thread iterator does not support intervals");
285 }
286 fn strides(&self) -> &Strides {
287 self.layout.strides()
288 }
289
290 fn shape(&self) -> &Shape {
291 self.layout.shape()
292 }
293
294 fn layout(&self) -> &Layout {
295 &self.layout
296 }
297
298 fn broadcast_set_strides(&mut self, shape: &Shape) {
299 let self_shape = try_pad_shape(self.shape(), shape.len());
300 self.set_strides(preprocess_strides(&self_shape, self.strides()).into());
301 self.last_stride = self.strides()[self.strides().len() - 1];
302 }
303
304 fn outer_loop_size(&self) -> usize {
305 (self.shape().iter().product::<i64>() as usize) / self.inner_loop_size()
306 }
307
308 fn inner_loop_size(&self) -> usize {
309 self.shape().last().unwrap().clone() as usize
310 }
311
312 fn next(&mut self) {
313 for j in (0..(self.shape().len() as i64) - 1).rev() {
314 let j = j as usize;
315 if self.prg[j] < self.shape()[j] - 1 {
316 self.prg[j] += 1;
317 self.ptr.offset(self.strides()[j]);
318 break;
319 } else {
320 self.prg[j] = 0;
321 self.ptr.offset(-self.strides()[j] * (self.shape()[j] - 1));
322 }
323 }
324 }
325
326 fn inner_loop_next(&mut self, index: usize) -> Self::Item {
327 unsafe { *self.ptr.get_ptr().add(index * (self.last_stride as usize)) }
328 }
329}
330
331impl<T: CommonBounds> ShapeManipulator for Strided<T> {
332 #[track_caller]
333 fn reshape<S: Into<Shape>>(self, shape: S) -> Self {
334 reshape(self, shape)
335 }
336
337 fn transpose<AXIS: Into<Axis>>(self, axes: AXIS) -> Self {
338 transpose(self, axes)
339 }
340
341 fn expand<S: Into<Shape>>(self, shape: S) -> Self {
342 expand(self, shape)
343 }
344}
345
346impl<T: CommonBounds> StridedIterator for Strided<T> {}
347
348impl<T> StridedHelper for Strided<T> {
349 fn _set_last_strides(&mut self, stride: i64) {
350 self.last_stride = stride;
351 }
352 fn _set_strides(&mut self, strides: Strides) {
353 self.layout.set_strides(strides);
354 }
355 fn _set_shape(&mut self, shape: Shape) {
356 self.layout.set_shape(shape);
357 }
358 fn _layout(&self) -> &Layout {
359 &self.layout
360 }
361}