hpt_iterator/
strided_map_mut.rs1use hpt_common::{shape::shape::Shape, strides::strides::Strides};
2use hpt_traits::tensor::{CommonBounds, TensorInfo};
3use hpt_types::dtype::TypeCommon;
4use std::sync::Arc;
5
6use crate::{
7 iterator_traits::{IterGetSet, StridedIterator},
8 par_strided_mut::ParStridedMut,
9 strided_zip::StridedZip,
10};
11
12pub mod strided_map_mut_simd {
14 use std::sync::Arc;
15
16 use crate::{CommonBounds, TensorInfo};
17 use hpt_common::{shape::shape::Shape, strides::strides::Strides};
18 use hpt_types::dtype::TypeCommon;
19
20 use crate::{
21 iterator_traits::{IterGetSetSimd, StridedIteratorSimd},
22 par_strided_mut::par_strided_map_mut_simd::ParStridedMutSimd,
23 strided_zip::strided_zip_simd::StridedZipSimd,
24 };
25
26 pub struct StridedMapMutSimd<'a, T>
31 where
32 T: Copy + TypeCommon + Send + Sync,
33 {
34 pub(crate) base: ParStridedMutSimd<'a, T>,
36 pub(crate) phantom: std::marker::PhantomData<&'a ()>,
38 }
39 impl<'a, T> StridedMapMutSimd<'a, T>
40 where
41 T: CommonBounds,
42 T::Vec: Send,
43 {
44 pub fn new<U: TensorInfo<T>>(res_tensor: U) -> Self {
54 StridedMapMutSimd {
55 base: ParStridedMutSimd::new(res_tensor),
56 phantom: std::marker::PhantomData,
57 }
58 }
59 pub(crate) fn zip<C>(self, other: C) -> StridedZipSimd<'a, Self, C>
75 where
76 C: 'a + IterGetSetSimd,
77 <C as IterGetSetSimd>::Item: Send,
78 {
79 StridedZipSimd::new(self, other)
80 }
81 }
82 impl<'a, T> StridedIteratorSimd for StridedMapMutSimd<'a, T> where T: 'a + CommonBounds {}
83 impl<'a, T: 'a + CommonBounds> IterGetSetSimd for StridedMapMutSimd<'a, T>
84 where
85 T::Vec: Send,
86 {
87 type Item = &'a mut T;
88 type SimdItem = &'a mut T::Vec;
89
90 fn set_end_index(&mut self, _: usize) {}
91
92 fn set_intervals(&mut self, _: Arc<Vec<(usize, usize)>>) {}
93
94 fn set_strides(&mut self, strides: Strides) {
95 self.base.set_strides(strides);
96 }
97
98 fn set_shape(&mut self, shape: Shape) {
99 self.base.set_shape(shape);
100 }
101
102 fn set_prg(&mut self, prg: Vec<i64>) {
103 self.base.set_prg(prg);
104 }
105
106 fn intervals(&self) -> &Arc<Vec<(usize, usize)>> {
107 self.base.intervals()
108 }
109
110 fn strides(&self) -> &Strides {
111 self.base.strides()
112 }
113
114 fn shape(&self) -> &Shape {
115 self.base.shape()
116 }
117
118 fn layout(&self) -> &hpt_common::layout::layout::Layout {
119 self.base.layout()
120 }
121
122 fn broadcast_set_strides(&mut self, shape: &Shape) {
123 self.base.broadcast_set_strides(shape);
124 }
125
126 fn outer_loop_size(&self) -> usize {
127 self.base.outer_loop_size()
128 }
129
130 fn inner_loop_size(&self) -> usize {
131 self.base.inner_loop_size()
132 }
133
134 fn next(&mut self) {
135 self.base.next();
136 }
137
138 fn next_simd(&mut self) {
139 todo!()
140 }
141
142 fn inner_loop_next(&mut self, index: usize) -> Self::Item {
143 self.base.inner_loop_next(index)
144 }
145
146 fn inner_loop_next_simd(&mut self, _: usize) -> Self::SimdItem {
147 todo!()
148 }
149
150 fn all_last_stride_one(&self) -> bool {
151 todo!()
152 }
153
154 fn lanes(&self) -> Option<usize> {
155 todo!()
156 }
157 }
158}
159
160pub struct StridedMapMut<'a, T>
164where
165 T: Copy + TypeCommon,
166{
167 pub(crate) base: ParStridedMut<'a, T>,
169 pub(crate) phantom: std::marker::PhantomData<&'a ()>,
171}
172
173impl<'a, T> StridedMapMut<'a, T>
174where
175 T: CommonBounds,
176 T::Vec: Send,
177{
178 pub fn new<U: TensorInfo<T>>(res_tensor: U) -> Self {
188 StridedMapMut {
189 base: ParStridedMut::new(res_tensor),
190 phantom: std::marker::PhantomData,
191 }
192 }
193
194 pub fn zip<C>(self, other: C) -> StridedZip<'a, Self, C>
210 where
211 C: 'a + IterGetSet,
212 <C as IterGetSet>::Item: Send,
213 {
214 StridedZip::new(self, other)
215 }
216}
217
218impl<'a, T> StridedIterator for StridedMapMut<'a, T> where T: 'a + CommonBounds {}
219
220impl<'a, T: 'a + CommonBounds> IterGetSet for StridedMapMut<'a, T>
221where
222 T::Vec: Send,
223{
224 type Item = &'a mut T;
225
226 fn set_end_index(&mut self, _: usize) {}
227
228 fn set_intervals(&mut self, _: Arc<Vec<(usize, usize)>>) {}
229
230 fn set_strides(&mut self, strides: Strides) {
231 self.base.set_strides(strides);
232 }
233
234 fn set_shape(&mut self, shape: Shape) {
235 self.base.set_shape(shape);
236 }
237
238 fn set_prg(&mut self, prg: Vec<i64>) {
239 self.base.set_prg(prg);
240 }
241
242 fn intervals(&self) -> &Arc<Vec<(usize, usize)>> {
243 self.base.intervals()
244 }
245
246 fn strides(&self) -> &Strides {
247 self.base.strides()
248 }
249
250 fn shape(&self) -> &Shape {
251 self.base.shape()
252 }
253
254 fn layout(&self) -> &hpt_common::layout::layout::Layout {
255 self.base.layout()
256 }
257
258 fn broadcast_set_strides(&mut self, shape: &Shape) {
259 self.base.broadcast_set_strides(shape);
260 }
261
262 fn outer_loop_size(&self) -> usize {
263 self.base.outer_loop_size()
264 }
265
266 fn inner_loop_size(&self) -> usize {
267 self.base.inner_loop_size()
268 }
269
270 fn next(&mut self) {
271 self.base.next();
272 }
273
274 fn inner_loop_next(&mut self, index: usize) -> Self::Item {
275 self.base.inner_loop_next(index)
276 }
277}