../../.cargo/katex-header.html

plonky2/util/
strided_view.rs

1use core::marker::PhantomData;
2use core::mem::size_of;
3use core::ops::{Index, Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive};
4
5use crate::field::packed::PackedField;
6
/// Imagine a slice, but with a stride (a la a NumPy array).
///
/// For example, if the stride is 3,
///     `packed_strided_view[0]` is `data[0]`,
///     `packed_strided_view[1]` is `data[3]`,
///     `packed_strided_view[2]` is `data[6]`,
/// and so on. An offset may be specified. With an offset of 1, we get
///     `packed_strided_view[0]` is `data[1]`,
///     `packed_strided_view[1]` is `data[4]`,
///     `packed_strided_view[2]` is `data[7]`,
/// and so on.
///
/// Additionally, this view is *packed*, which means that it may yield a packing of the underlying
/// field slice. With a packing of width 4 and a stride of 5, the accesses are
///     `packed_strided_view[0]` is `data[0..4]`, transmuted to the packing,
///     `packed_strided_view[1]` is `data[5..9]`, transmuted to the packing,
///     `packed_strided_view[2]` is `data[10..14]`, transmuted to the packing,
/// and so on.
#[derive(Debug, Copy, Clone)]
pub struct PackedStridedView<'a, P: PackedField> {
    // This type has to be a struct, which means that it is not itself a reference (in the sense
    // that a slice is a reference so we can return it from e.g. `Index::index`).

    // Raw pointers rarely appear in good Rust code, but I think this is the most elegant way to
    // implement this. The alternative would be to replace `start_ptr` and `length` with one slice
    // (`&[P::Scalar]`). Unfortunately, with a slice, an empty view becomes an edge case that
    // necessitates separate handling. It _could_ be done but it would also be uglier.
    /// Pointer to the first scalar accessible through the view (`data.as_ptr() + offset`). It may
    /// point past the end of the allocation when the view is empty and `offset != 0`; it is never
    /// dereferenced in that case.
    start_ptr: *const P::Scalar,
    /// This is the total length of elements accessible through the view. In other words, valid
    /// indices are in `0..length`.
    length: usize,
    /// This stride is in units of `P::Scalar` (NOT in bytes and NOT in `P`).
    stride: usize,
    /// Ties the raw pointer to the lifetime `'a` of the borrowed scalar slice.
    _phantom: PhantomData<&'a [P::Scalar]>,
}
42
43impl<'a, P: PackedField> PackedStridedView<'a, P> {
44    // `wrapping_add` is needed throughout to avoid undefined behavior. Plain `add` causes UB if
45    // '[either] the starting [or] resulting pointer [is neither] in bounds or one byte past the
46    // end of the same allocated object'; the UB results even if the pointer is not dereferenced.
47
48    #[inline]
49    pub fn new(data: &'a [P::Scalar], stride: usize, offset: usize) -> Self {
50        assert!(
51            stride >= P::WIDTH,
52            "stride (got {}) must be at least P::WIDTH ({})",
53            stride,
54            P::WIDTH
55        );
56        assert_eq!(
57            data.len() % stride,
58            0,
59            "data.len() ({}) must be a multiple of stride (got {})",
60            data.len(),
61            stride
62        );
63
64        // This requirement means that stride divides data into slices of `data.len() / stride`
65        // elements. Every access must fit entirely within one of those slices.
66        assert!(
67            offset + P::WIDTH <= stride,
68            "offset (got {}) + P::WIDTH ({}) cannot be greater than stride (got {})",
69            offset,
70            P::WIDTH,
71            stride
72        );
73
74        // See comment above. `start_ptr` will be more than one byte past the buffer if `data` has
75        // length 0 and `offset` is not 0.
76        let start_ptr = data.as_ptr().wrapping_add(offset);
77
78        Self {
79            start_ptr,
80            length: data.len() / stride,
81            stride,
82            _phantom: PhantomData,
83        }
84    }
85
86    #[inline]
87    pub const fn get(&self, index: usize) -> Option<&'a P> {
88        if index < self.length {
89            // Cast scalar pointer to vector pointer.
90            let res_ptr = unsafe { self.start_ptr.add(index * self.stride) }.cast();
91            // This transmutation is safe by the spec in `PackedField`.
92            Some(unsafe { &*res_ptr })
93        } else {
94            None
95        }
96    }
97
98    /// Take a range of `PackedStridedView` indices, as `PackedStridedView`.
99    #[inline]
100    pub fn view<I>(&self, index: I) -> Self
101    where
102        Self: Viewable<I, View = Self>,
103    {
104        // We cannot implement `Index` as `PackedStridedView` is a struct, not a reference.
105
106        // The `Viewable` trait is needed for overloading.
107        // Re-export `Viewable::view` so users don't have to import `Viewable`.
108        <Self as Viewable<I>>::view(self, index)
109    }
110
111    #[inline]
112    pub const fn iter(&self) -> PackedStridedViewIter<'a, P> {
113        PackedStridedViewIter::new(
114            self.start_ptr,
115            // See comment at the top of the `impl`. Below will point more than one byte past the
116            // end of the buffer (unless `offset` is 0) so `wrapping_add` is needed.
117            self.start_ptr.wrapping_add(self.length * self.stride),
118            self.stride,
119        )
120    }
121
122    #[inline]
123    pub const fn len(&self) -> usize {
124        self.length
125    }
126
127    #[inline]
128    pub const fn is_empty(&self) -> bool {
129        self.len() == 0
130    }
131}
132
133impl<P: PackedField> Index<usize> for PackedStridedView<'_, P> {
134    type Output = P;
135    #[inline]
136    fn index(&self, index: usize) -> &Self::Output {
137        self.get(index)
138            .expect("invalid memory access in PackedStridedView")
139    }
140}
141
142impl<'a, P: PackedField> IntoIterator for PackedStridedView<'a, P> {
143    type Item = &'a P;
144    type IntoIter = PackedStridedViewIter<'a, P>;
145    fn into_iter(self) -> Self::IntoIter {
146        self.iter()
147    }
148}
149
/// Error returned when converting a [`PackedStridedView`] into a fixed-size array whose length
/// does not match the view's length.
#[derive(Clone, Copy, Debug)]
pub struct TryFromPackedStridedViewError;
152
153impl<P: PackedField, const N: usize> TryInto<[P; N]> for PackedStridedView<'_, P> {
154    type Error = TryFromPackedStridedViewError;
155    fn try_into(self) -> Result<[P; N], Self::Error> {
156        if N == self.len() {
157            let mut res = [P::default(); N];
158            for i in 0..N {
159                res[i] = *self.get(i).unwrap();
160            }
161            Ok(res)
162        } else {
163            Err(TryFromPackedStridedViewError)
164        }
165    }
166}
167
// Not deriving `Copy`. An implicit copy of an iterator is likely a bug.
/// Double-ended iterator over the packed values of a [`PackedStridedView`].
#[derive(Clone, Debug)]
pub struct PackedStridedViewIter<'a, P: PackedField> {
    // Again, a pair of pointers is a neater solution than a slice. `start` and `end` are always
    // separated by a multiple of stride elements. To advance the iterator from the front, we
    // advance `start` by `stride` elements. To advance it from the end, we subtract `stride`
    // elements. Iteration is done when they meet.
    // A slice cannot recreate the same pattern. The end pointer may point past the underlying
    // buffer (this is okay as we do not dereference it in that case); it becomes valid as soon as
    // it is decreased by `stride`. On the other hand, a slice that ends on invalid memory is
    // instant undefined behavior.
    start: *const P::Scalar,
    end: *const P::Scalar,
    /// Stride in units of `P::Scalar` (NOT in bytes and NOT in `P`).
    stride: usize,
    /// Ties the raw pointers to the lifetime of the underlying scalar slice.
    _phantom: PhantomData<&'a [P::Scalar]>,
}
184
185impl<P: PackedField> PackedStridedViewIter<'_, P> {
186    pub(self) const fn new(start: *const P::Scalar, end: *const P::Scalar, stride: usize) -> Self {
187        Self {
188            start,
189            end,
190            stride,
191            _phantom: PhantomData,
192        }
193    }
194}
195
196impl<'a, P: PackedField> Iterator for PackedStridedViewIter<'a, P> {
197    type Item = &'a P;
198    fn next(&mut self) -> Option<Self::Item> {
199        debug_assert_eq!(
200            (self.end as usize).wrapping_sub(self.start as usize)
201                % (self.stride * size_of::<P::Scalar>()),
202            0,
203            "start and end pointers should be separated by a multiple of stride"
204        );
205
206        if !core::ptr::eq(self.start, self.end) {
207            let res = unsafe { &*self.start.cast() };
208            // See comment in `PackedStridedView`. Below will point more than one byte past the end
209            // of the buffer if the offset is not 0 and we've reached the end.
210            self.start = self.start.wrapping_add(self.stride);
211            Some(res)
212        } else {
213            None
214        }
215    }
216}
217
218impl<P: PackedField> DoubleEndedIterator for PackedStridedViewIter<'_, P> {
219    fn next_back(&mut self) -> Option<Self::Item> {
220        debug_assert_eq!(
221            (self.end as usize).wrapping_sub(self.start as usize)
222                % (self.stride * size_of::<P::Scalar>()),
223            0,
224            "start and end pointers should be separated by a multiple of stride"
225        );
226
227        if !core::ptr::eq(self.start, self.end) {
228            // See comment in `PackedStridedView`. `self.end` starts off pointing more than one byte
229            // past the end of the buffer unless `offset` is 0.
230            self.end = self.end.wrapping_sub(self.stride);
231            Some(unsafe { &*self.end.cast() })
232        } else {
233            None
234        }
235    }
236}
237
/// Range-indexing for [`PackedStridedView`], overloaded over the range type `F`.
pub trait Viewable<F> {
    // We cannot implement `Index` as `PackedStridedView` is a struct, not a reference.
    type View;
    /// Take the sub-view of `self` selected by `index`.
    fn view(&self, index: F) -> Self::View;
}
243
244impl<P: PackedField> Viewable<Range<usize>> for PackedStridedView<'_, P> {
245    type View = Self;
246    fn view(&self, range: Range<usize>) -> Self::View {
247        assert!(range.start <= self.len(), "Invalid access");
248        assert!(range.end <= self.len(), "Invalid access");
249        Self {
250            // See comment in `PackedStridedView`. `self.start_ptr` will point more than one byte
251            // past the end of the buffer if the offset is not 0 and the buffer has length 0.
252            start_ptr: self.start_ptr.wrapping_add(self.stride * range.start),
253            length: range.end - range.start,
254            stride: self.stride,
255            _phantom: PhantomData,
256        }
257    }
258}
259
260impl<P: PackedField> Viewable<RangeFrom<usize>> for PackedStridedView<'_, P> {
261    type View = Self;
262    fn view(&self, range: RangeFrom<usize>) -> Self::View {
263        assert!(range.start <= self.len(), "Invalid access");
264        Self {
265            // See comment in `PackedStridedView`. `self.start_ptr` will point more than one byte
266            // past the end of the buffer if the offset is not 0 and the buffer has length 0.
267            start_ptr: self.start_ptr.wrapping_add(self.stride * range.start),
268            length: self.len() - range.start,
269            stride: self.stride,
270            _phantom: PhantomData,
271        }
272    }
273}
274
275impl<P: PackedField> Viewable<RangeFull> for PackedStridedView<'_, P> {
276    type View = Self;
277    fn view(&self, _range: RangeFull) -> Self::View {
278        *self
279    }
280}
281
282impl<P: PackedField> Viewable<RangeInclusive<usize>> for PackedStridedView<'_, P> {
283    type View = Self;
284    fn view(&self, range: RangeInclusive<usize>) -> Self::View {
285        assert!(*range.start() <= self.len(), "Invalid access");
286        assert!(*range.end() < self.len(), "Invalid access");
287        Self {
288            // See comment in `PackedStridedView`. `self.start_ptr` will point more than one byte
289            // past the end of the buffer if the offset is not 0 and the buffer has length 0.
290            start_ptr: self.start_ptr.wrapping_add(self.stride * range.start()),
291            length: range.end() - range.start() + 1,
292            stride: self.stride,
293            _phantom: PhantomData,
294        }
295    }
296}
297
298impl<P: PackedField> Viewable<RangeTo<usize>> for PackedStridedView<'_, P> {
299    type View = Self;
300    fn view(&self, range: RangeTo<usize>) -> Self::View {
301        assert!(range.end <= self.len(), "Invalid access");
302        Self {
303            start_ptr: self.start_ptr,
304            length: range.end,
305            stride: self.stride,
306            _phantom: PhantomData,
307        }
308    }
309}
310
311impl<P: PackedField> Viewable<RangeToInclusive<usize>> for PackedStridedView<'_, P> {
312    type View = Self;
313    fn view(&self, range: RangeToInclusive<usize>) -> Self::View {
314        assert!(range.end < self.len(), "Invalid access");
315        Self {
316            start_ptr: self.start_ptr,
317            length: range.end + 1,
318            stride: self.stride,
319            _phantom: PhantomData,
320        }
321    }
322}