hpt_iterator/
strided_mut.rs1use crate::{
2 iterator_traits::{IterGetSet, StridedIterator, StridedIteratorZip},
3 strided::Strided,
4 strided_zip::StridedZip,
5};
6use hpt_common::{shape::shape::Shape, shape::shape_utils::predict_broadcast_shape};
7use hpt_traits::tensor::{CommonBounds, TensorInfo};
8use std::sync::Arc;
9
10pub mod simd_imports {
12 use crate::{
13 iterator_traits::{IterGetSetSimd, StridedIteratorSimd, StridedSimdIteratorZip},
14 strided::strided_simd::StridedSimd,
15 };
16 use hpt_common::shape::shape::Shape;
17 use hpt_traits::{CommonBounds, TensorInfo};
18 use hpt_types::dtype::TypeCommon;
19 use hpt_types::vectors::traits::VecTrait;
20 use std::sync::Arc;
21
22 pub struct StridedMutSimd<'a, T: TypeCommon> {
26 pub(crate) base: StridedSimd<T>,
28 pub(crate) last_stride: i64,
30 pub(crate) phantom: std::marker::PhantomData<&'a ()>,
32 }
33
34 impl<'a, T: CommonBounds> StridedMutSimd<'a, T> {
35 pub fn new<U: TensorInfo<T>>(tensor: U) -> Self {
52 let base = StridedSimd::new(tensor);
53 let last_stride = base.last_stride;
54 StridedMutSimd {
55 base,
56 last_stride,
57 phantom: std::marker::PhantomData,
58 }
59 }
60 }
61
62 impl<'a, T: 'a> IterGetSetSimd for StridedMutSimd<'a, T>
63 where
64 T: CommonBounds,
65 {
66 type Item = &'a mut T;
67 type SimdItem = &'a mut T::Vec;
68
69 fn set_end_index(&mut self, end_index: usize) {
70 self.base.set_end_index(end_index);
71 }
72
73 fn set_intervals(&mut self, intervals: Arc<Vec<(usize, usize)>>) {
74 self.base.set_intervals(intervals);
75 }
76
77 fn set_strides(&mut self, strides: hpt_common::strides::strides::Strides) {
78 self.base.set_strides(strides);
79 }
80
81 fn set_shape(&mut self, shape: Shape) {
82 self.base.set_shape(shape);
83 }
84
85 fn set_prg(&mut self, prg: Vec<i64>) {
86 self.base.set_prg(prg);
87 }
88
89 fn intervals(&self) -> &Arc<Vec<(usize, usize)>> {
90 self.base.intervals()
91 }
92
93 fn strides(&self) -> &hpt_common::strides::strides::Strides {
94 self.base.strides()
95 }
96
97 fn shape(&self) -> &Shape {
98 self.base.shape()
99 }
100
101 fn layout(&self) -> &hpt_common::layout::layout::Layout {
102 self.base.layout()
103 }
104
105 fn broadcast_set_strides(&mut self, shape: &Shape) {
106 self.base.broadcast_set_strides(shape);
107 }
108
109 fn outer_loop_size(&self) -> usize {
110 self.base.outer_loop_size()
111 }
112 fn inner_loop_size(&self) -> usize {
113 self.base.inner_loop_size()
114 }
115
116 fn next(&mut self) {
117 self.base.next();
118 }
119 fn next_simd(&mut self) {
120 todo!()
121 }
122 #[inline(always)]
123 fn inner_loop_next(&mut self, index: usize) -> Self::Item {
124 unsafe {
125 &mut *self
126 .base
127 .ptr
128 .ptr
129 .offset((index as isize) * (self.last_stride as isize))
130 }
131 }
132 fn inner_loop_next_simd(&mut self, index: usize) -> Self::SimdItem {
133 let vector = unsafe { self.base.ptr.ptr.add(index * T::Vec::SIZE) };
134 unsafe { std::mem::transmute(vector) }
135 }
136 fn all_last_stride_one(&self) -> bool {
137 self.base.all_last_stride_one()
138 }
139
140 fn lanes(&self) -> Option<usize> {
141 self.base.lanes()
142 }
143 }
144 impl<'a, T> StridedIteratorSimd for StridedMutSimd<'a, T> where T: CommonBounds {}
145 impl<'a, T> StridedSimdIteratorZip for StridedMutSimd<'a, T> where T: CommonBounds {}
146}
147
148pub struct StridedMut<'a, T> {
152 pub(crate) base: Strided<T>,
154 pub(crate) phantom: std::marker::PhantomData<&'a ()>,
156}
157
158impl<'a, T: CommonBounds> StridedMut<'a, T> {
159 pub fn new<U: TensorInfo<T>>(tensor: U) -> Self {
169 StridedMut {
170 base: Strided::new(tensor),
171 phantom: std::marker::PhantomData,
172 }
173 }
174 #[track_caller]
196 pub fn zip<C>(mut self, mut other: C) -> StridedZip<'a, Self, C>
197 where
198 C: 'a + IterGetSet,
199 <C as IterGetSet>::Item: Send,
200 {
201 let new_shape = match predict_broadcast_shape(self.shape(), other.shape()) {
202 Ok(s) => s,
203 Err(err) => {
204 panic!("{}", err);
205 }
206 };
207
208 other.broadcast_set_strides(&new_shape);
209 self.broadcast_set_strides(&new_shape);
210
211 other.set_shape(new_shape.clone());
212 self.set_shape(new_shape.clone());
213
214 StridedZip::new(self, other)
215 }
216}
217
218impl<'a, T: CommonBounds> StridedIterator for StridedMut<'a, T> {}
219impl<'a, T: CommonBounds> StridedIteratorZip for StridedMut<'a, T> {}
220
221impl<'a, T: 'a> IterGetSet for StridedMut<'a, T>
222where
223 T: CommonBounds,
224{
225 type Item = &'a mut T;
226
227 fn set_end_index(&mut self, end_index: usize) {
228 self.base.set_end_index(end_index);
229 }
230
231 fn set_intervals(&mut self, intervals: Arc<Vec<(usize, usize)>>) {
232 self.base.set_intervals(intervals);
233 }
234
235 fn set_strides(&mut self, strides: hpt_common::strides::strides::Strides) {
236 self.base.set_strides(strides);
237 }
238
239 fn set_shape(&mut self, shape: Shape) {
240 self.base.set_shape(shape);
241 }
242
243 fn set_prg(&mut self, prg: Vec<i64>) {
244 self.base.set_prg(prg);
245 }
246
247 fn intervals(&self) -> &Arc<Vec<(usize, usize)>> {
248 self.base.intervals()
249 }
250
251 fn strides(&self) -> &hpt_common::strides::strides::Strides {
252 self.base.strides()
253 }
254
255 fn shape(&self) -> &Shape {
256 self.base.shape()
257 }
258
259 fn layout(&self) -> &hpt_common::layout::layout::Layout {
260 self.base.layout()
261 }
262
263 fn broadcast_set_strides(&mut self, shape: &Shape) {
264 self.base.broadcast_set_strides(shape);
265 }
266
267 fn outer_loop_size(&self) -> usize {
268 self.base.outer_loop_size()
269 }
270
271 fn inner_loop_size(&self) -> usize {
272 self.base.inner_loop_size()
273 }
274
275 fn next(&mut self) {
276 self.base.next();
277 }
278
279 fn inner_loop_next(&mut self, index: usize) -> Self::Item {
280 unsafe {
281 self.base
282 .ptr
283 .get_ptr()
284 .add(index * (self.base.last_stride as usize))
285 .as_mut()
286 .unwrap()
287 }
288 }
289}