1use core::marker::PhantomData;
2use core::mem::size_of;
3use core::ops::{Index, Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive};
4
5use crate::field::packed::PackedField;
6
7#[derive(Debug, Copy, Clone)]
26pub struct PackedStridedView<'a, P: PackedField> {
27 start_ptr: *const P::Scalar,
35 length: usize,
38 stride: usize,
40 _phantom: PhantomData<&'a [P::Scalar]>,
41}
42
43impl<'a, P: PackedField> PackedStridedView<'a, P> {
44 #[inline]
49 pub fn new(data: &'a [P::Scalar], stride: usize, offset: usize) -> Self {
50 assert!(
51 stride >= P::WIDTH,
52 "stride (got {}) must be at least P::WIDTH ({})",
53 stride,
54 P::WIDTH
55 );
56 assert_eq!(
57 data.len() % stride,
58 0,
59 "data.len() ({}) must be a multiple of stride (got {})",
60 data.len(),
61 stride
62 );
63
64 assert!(
67 offset + P::WIDTH <= stride,
68 "offset (got {}) + P::WIDTH ({}) cannot be greater than stride (got {})",
69 offset,
70 P::WIDTH,
71 stride
72 );
73
74 let start_ptr = data.as_ptr().wrapping_add(offset);
77
78 Self {
79 start_ptr,
80 length: data.len() / stride,
81 stride,
82 _phantom: PhantomData,
83 }
84 }
85
86 #[inline]
87 pub const fn get(&self, index: usize) -> Option<&'a P> {
88 if index < self.length {
89 let res_ptr = unsafe { self.start_ptr.add(index * self.stride) }.cast();
91 Some(unsafe { &*res_ptr })
93 } else {
94 None
95 }
96 }
97
98 #[inline]
100 pub fn view<I>(&self, index: I) -> Self
101 where
102 Self: Viewable<I, View = Self>,
103 {
104 <Self as Viewable<I>>::view(self, index)
109 }
110
111 #[inline]
112 pub const fn iter(&self) -> PackedStridedViewIter<'a, P> {
113 PackedStridedViewIter::new(
114 self.start_ptr,
115 self.start_ptr.wrapping_add(self.length * self.stride),
118 self.stride,
119 )
120 }
121
122 #[inline]
123 pub const fn len(&self) -> usize {
124 self.length
125 }
126
127 #[inline]
128 pub const fn is_empty(&self) -> bool {
129 self.len() == 0
130 }
131}
132
133impl<P: PackedField> Index<usize> for PackedStridedView<'_, P> {
134 type Output = P;
135 #[inline]
136 fn index(&self, index: usize) -> &Self::Output {
137 self.get(index)
138 .expect("invalid memory access in PackedStridedView")
139 }
140}
141
142impl<'a, P: PackedField> IntoIterator for PackedStridedView<'a, P> {
143 type Item = &'a P;
144 type IntoIter = PackedStridedViewIter<'a, P>;
145 fn into_iter(self) -> Self::IntoIter {
146 self.iter()
147 }
148}
149
/// Error returned when converting a [`PackedStridedView`] into a fixed-size array whose
/// length differs from the view's length.
#[derive(Debug, Clone, Copy)]
pub struct TryFromPackedStridedViewError;
152
153impl<P: PackedField, const N: usize> TryInto<[P; N]> for PackedStridedView<'_, P> {
154 type Error = TryFromPackedStridedViewError;
155 fn try_into(self) -> Result<[P; N], Self::Error> {
156 if N == self.len() {
157 let mut res = [P::default(); N];
158 for i in 0..N {
159 res[i] = *self.get(i).unwrap();
160 }
161 Ok(res)
162 } else {
163 Err(TryFromPackedStridedViewError)
164 }
165 }
166}
167
168#[derive(Clone, Debug)]
170pub struct PackedStridedViewIter<'a, P: PackedField> {
171 start: *const P::Scalar,
180 end: *const P::Scalar,
181 stride: usize,
182 _phantom: PhantomData<&'a [P::Scalar]>,
183}
184
185impl<P: PackedField> PackedStridedViewIter<'_, P> {
186 pub(self) const fn new(start: *const P::Scalar, end: *const P::Scalar, stride: usize) -> Self {
187 Self {
188 start,
189 end,
190 stride,
191 _phantom: PhantomData,
192 }
193 }
194}
195
196impl<'a, P: PackedField> Iterator for PackedStridedViewIter<'a, P> {
197 type Item = &'a P;
198 fn next(&mut self) -> Option<Self::Item> {
199 debug_assert_eq!(
200 (self.end as usize).wrapping_sub(self.start as usize)
201 % (self.stride * size_of::<P::Scalar>()),
202 0,
203 "start and end pointers should be separated by a multiple of stride"
204 );
205
206 if !core::ptr::eq(self.start, self.end) {
207 let res = unsafe { &*self.start.cast() };
208 self.start = self.start.wrapping_add(self.stride);
211 Some(res)
212 } else {
213 None
214 }
215 }
216}
217
218impl<P: PackedField> DoubleEndedIterator for PackedStridedViewIter<'_, P> {
219 fn next_back(&mut self) -> Option<Self::Item> {
220 debug_assert_eq!(
221 (self.end as usize).wrapping_sub(self.start as usize)
222 % (self.stride * size_of::<P::Scalar>()),
223 0,
224 "start and end pointers should be separated by a multiple of stride"
225 );
226
227 if !core::ptr::eq(self.start, self.end) {
228 self.end = self.end.wrapping_sub(self.stride);
231 Some(unsafe { &*self.end.cast() })
232 } else {
233 None
234 }
235 }
236}
237
/// Types that can be narrowed to a sub-view selected by an index of type `Idx`.
///
/// `PackedStridedView` implements this for the `core::ops` range types.
pub trait Viewable<Idx> {
    /// The type of the narrowed view.
    type View;

    /// Returns the part of `self` selected by `index`.
    fn view(&self, index: Idx) -> Self::View;
}
243
244impl<P: PackedField> Viewable<Range<usize>> for PackedStridedView<'_, P> {
245 type View = Self;
246 fn view(&self, range: Range<usize>) -> Self::View {
247 assert!(range.start <= self.len(), "Invalid access");
248 assert!(range.end <= self.len(), "Invalid access");
249 Self {
250 start_ptr: self.start_ptr.wrapping_add(self.stride * range.start),
253 length: range.end - range.start,
254 stride: self.stride,
255 _phantom: PhantomData,
256 }
257 }
258}
259
260impl<P: PackedField> Viewable<RangeFrom<usize>> for PackedStridedView<'_, P> {
261 type View = Self;
262 fn view(&self, range: RangeFrom<usize>) -> Self::View {
263 assert!(range.start <= self.len(), "Invalid access");
264 Self {
265 start_ptr: self.start_ptr.wrapping_add(self.stride * range.start),
268 length: self.len() - range.start,
269 stride: self.stride,
270 _phantom: PhantomData,
271 }
272 }
273}
274
275impl<P: PackedField> Viewable<RangeFull> for PackedStridedView<'_, P> {
276 type View = Self;
277 fn view(&self, _range: RangeFull) -> Self::View {
278 *self
279 }
280}
281
282impl<P: PackedField> Viewable<RangeInclusive<usize>> for PackedStridedView<'_, P> {
283 type View = Self;
284 fn view(&self, range: RangeInclusive<usize>) -> Self::View {
285 assert!(*range.start() <= self.len(), "Invalid access");
286 assert!(*range.end() < self.len(), "Invalid access");
287 Self {
288 start_ptr: self.start_ptr.wrapping_add(self.stride * range.start()),
291 length: range.end() - range.start() + 1,
292 stride: self.stride,
293 _phantom: PhantomData,
294 }
295 }
296}
297
298impl<P: PackedField> Viewable<RangeTo<usize>> for PackedStridedView<'_, P> {
299 type View = Self;
300 fn view(&self, range: RangeTo<usize>) -> Self::View {
301 assert!(range.end <= self.len(), "Invalid access");
302 Self {
303 start_ptr: self.start_ptr,
304 length: range.end,
305 stride: self.stride,
306 _phantom: PhantomData,
307 }
308 }
309}
310
311impl<P: PackedField> Viewable<RangeToInclusive<usize>> for PackedStridedView<'_, P> {
312 type View = Self;
313 fn view(&self, range: RangeToInclusive<usize>) -> Self::View {
314 assert!(range.end < self.len(), "Invalid access");
315 Self {
316 start_ptr: self.start_ptr,
317 length: range.end + 1,
318 stride: self.stride,
319 _phantom: PhantomData,
320 }
321 }
322}