// ark_ff/const_helpers.rs

1use ark_serialize::{Read, Write};
2use ark_std::ops::{Index, IndexMut};
3
4use crate::BigInt;
5
/// A helper macro for emulating `for` loops in a `const` context.
///
/// `const_for!((i in start..end) { body })` runs `body` once for every value of
/// `i` in the half-open range `start..end`, in increasing order. `start` and
/// `end` must each be a single token tree, so any non-trivial expression has to
/// be wrapped in parentheses. `end` is re-evaluated before every iteration.
///
/// `break` and `continue` inside `body` behave as in a regular `for` loop; in
/// particular `continue` still advances `i` to the next value.
///
/// # Usage
/// ```rust
/// # use ark_ff::const_for;
/// const fn for_in_const() {
///     let mut array = [0usize; 4];
///     const_for!((i in 0..(array.len())) { // We need to wrap the `array.len()` in parentheses.
///         array[i] = i;
///     });
///     assert!(array[0] == 0);
///     assert!(array[1] == 1);
///     assert!(array[2] == 2);
///     assert!(array[3] == 3);
/// }
/// ```
#[macro_export]
macro_rules! const_for {
    (($i:ident in $start:tt..$end:tt)  $code:expr ) => {{
        let mut $i = $start;
        // The increment lives at the *top* of the loop (skipped on the first
        // pass) instead of after `$code`, so that a `continue` inside `$code`
        // still advances `$i` rather than re-running the same index forever.
        // The sequence of values `$i` takes is otherwise unchanged.
        let mut started = false;
        loop {
            if started {
                $i += 1;
            }
            started = true;
            if $i >= $end {
                break;
            }
            $code
        }
    }};
}

32/// A buffer to hold values of size 2 * N. This is mostly
33/// a hack that's necessary until `generic_const_exprs` is stable.
34#[derive(Copy, Clone)]
35#[repr(C, align(8))]
36pub(super) struct MulBuffer<const N: usize> {
37    pub(super) b0: [u64; N],
38    pub(super) b1: [u64; N],
39}
40
41impl<const N: usize> MulBuffer<N> {
42    const fn new(b0: [u64; N], b1: [u64; N]) -> Self {
43        Self { b0, b1 }
44    }
45
46    pub(super) const fn zeroed() -> Self {
47        let b = [0u64; N];
48        Self::new(b, b)
49    }
50
51    #[inline(always)]
52    pub(super) const fn get(&self, index: usize) -> &u64 {
53        if index < N {
54            &self.b0[index]
55        } else {
56            &self.b1[index - N]
57        }
58    }
59
60    #[inline(always)]
61    pub(super) fn get_mut(&mut self, index: usize) -> &mut u64 {
62        if index < N {
63            &mut self.b0[index]
64        } else {
65            &mut self.b1[index - N]
66        }
67    }
68}
69
70impl<const N: usize> Index<usize> for MulBuffer<N> {
71    type Output = u64;
72    #[inline(always)]
73    fn index(&self, index: usize) -> &Self::Output {
74        self.get(index)
75    }
76}
77
78impl<const N: usize> IndexMut<usize> for MulBuffer<N> {
79    #[inline(always)]
80    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
81        self.get_mut(index)
82    }
83}
84
85/// A buffer to hold values of size 8 * N + 1 bytes. This is mostly
86/// a hack that's necessary until `generic_const_exprs` is stable.
87#[derive(Copy, Clone)]
88#[repr(C, align(1))]
89pub(super) struct SerBuffer<const N: usize> {
90    pub(super) buffers: [[u8; 8]; N],
91    pub(super) last: u8,
92}
93
94impl<const N: usize> SerBuffer<N> {
95    pub(super) const fn zeroed() -> Self {
96        Self {
97            buffers: [[0u8; 8]; N],
98            last: 0u8,
99        }
100    }
101
102    #[inline(always)]
103    pub(super) const fn get(&self, index: usize) -> &u8 {
104        if index == 8 * N {
105            &self.last
106        } else {
107            let part = index / 8;
108            let in_buffer_index = index % 8;
109            &self.buffers[part][in_buffer_index]
110        }
111    }
112
113    #[inline(always)]
114    pub(super) fn get_mut(&mut self, index: usize) -> &mut u8 {
115        if index == 8 * N {
116            &mut self.last
117        } else {
118            let part = index / 8;
119            let in_buffer_index = index % 8;
120            &mut self.buffers[part][in_buffer_index]
121        }
122    }
123
124    #[allow(unsafe_code)]
125    pub(super) const fn as_slice(&self) -> &[u8] {
126        unsafe { ark_std::slice::from_raw_parts((self as *const Self) as *const u8, 8 * N + 1) }
127    }
128
129    #[inline(always)]
130    pub(super) fn last_n_plus_1_bytes_mut(&mut self) -> impl Iterator<Item = &mut u8> {
131        self.buffers[N - 1]
132            .iter_mut()
133            .chain(ark_std::iter::once(&mut self.last))
134    }
135
136    #[inline(always)]
137    pub(super) fn copy_from_u8_slice(&mut self, other: &[u8]) {
138        other.chunks(8).enumerate().for_each(|(i, chunk)| {
139            if i < N {
140                self.buffers[i][..chunk.len()].copy_from_slice(chunk);
141            } else {
142                self.last = chunk[0]
143            }
144        });
145    }
146
147    #[inline(always)]
148    pub(super) fn copy_from_u64_slice(&mut self, other: &[u64; N]) {
149        other
150            .iter()
151            .zip(&mut self.buffers)
152            .for_each(|(other, this)| *this = other.to_le_bytes());
153    }
154
155    #[inline(always)]
156    pub(super) fn to_bigint(self) -> BigInt<N> {
157        let mut self_integer = BigInt::from(0u64);
158        self_integer
159            .0
160            .iter_mut()
161            .zip(self.buffers)
162            .for_each(|(other, this)| *other = u64::from_le_bytes(this));
163        self_integer
164    }
165
166    #[inline(always)]
167    /// Write up to `num_bytes` bytes from `self` to `other`.
168    /// `num_bytes` is allowed to range from `8 * (N - 1) + 1` to `8 * N + 1`.
169    pub(super) fn write_up_to(
170        &self,
171        mut other: impl Write,
172        num_bytes: usize,
173    ) -> ark_std::io::Result<()> {
174        debug_assert!(num_bytes <= 8 * N + 1, "index too large");
175        debug_assert!(num_bytes > 8 * (N - 1), "index too small");
176        // unconditionally write first `N - 1` limbs.
177        for i in 0..(N - 1) {
178            other.write_all(&self.buffers[i])?;
179        }
180        // for the `N`-th limb, depending on `index`, we can write anywhere from
181        // 1 to all bytes.
182        let remaining_bytes = num_bytes - (8 * (N - 1));
183        let write_last_byte = remaining_bytes > 8;
184        let num_last_limb_bytes = ark_std::cmp::min(8, remaining_bytes);
185        other.write_all(&self.buffers[N - 1][..num_last_limb_bytes])?;
186        if write_last_byte {
187            other.write_all(&[self.last])?;
188        }
189        Ok(())
190    }
191
192    #[inline(always)]
193    /// Read up to `num_bytes` bytes from `other` to `self`.
194    /// `num_bytes` is allowed to range from `8 * (N - 1)` to `8 * N + 1`.
195    pub(super) fn read_exact_up_to(
196        &mut self,
197        mut other: impl Read,
198        num_bytes: usize,
199    ) -> ark_std::io::Result<()> {
200        debug_assert!(num_bytes <= 8 * N + 1, "index too large");
201        debug_assert!(num_bytes > 8 * (N - 1), "index too small");
202        // unconditionally write first `N - 1` limbs.
203        for i in 0..(N - 1) {
204            other.read_exact(&mut self.buffers[i])?;
205        }
206        // for the `N`-th limb, depending on `index`, we can write anywhere from
207        // 1 to all bytes.
208        let remaining_bytes = num_bytes - (8 * (N - 1));
209        let write_last_byte = remaining_bytes > 8;
210        let num_last_limb_bytes = ark_std::cmp::min(8, remaining_bytes);
211        other.read_exact(&mut self.buffers[N - 1][..num_last_limb_bytes])?;
212        if write_last_byte {
213            let mut last = [0u8; 1];
214            other.read_exact(&mut last)?;
215            self.last = last[0];
216        }
217        Ok(())
218    }
219}
220
221impl<const N: usize> Index<usize> for SerBuffer<N> {
222    type Output = u8;
223    #[inline(always)]
224    fn index(&self, index: usize) -> &Self::Output {
225        self.get(index)
226    }
227}
228
229impl<const N: usize> IndexMut<usize> for SerBuffer<N> {
230    #[inline(always)]
231    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
232        self.get_mut(index)
233    }
234}
235
236pub(super) struct RBuffer<const N: usize>(pub [u64; N], pub u64);
237
238impl<const N: usize> RBuffer<N> {
239    /// Find the number of bits in the binary decomposition of `self`.
240    pub(super) const fn num_bits(&self) -> u32 {
241        (N * 64) as u32 + (64 - self.1.leading_zeros())
242    }
243
244    /// Returns the `i`-th bit where bit 0 is the least significant one.
245    /// In other words, the bit with weight `2^i`.
246    pub(super) const fn get_bit(&self, i: usize) -> bool {
247        let d = i / 64;
248        let b = i % 64;
249        if d == N {
250            (self.1 >> b) & 1 == 1
251        } else {
252            (self.0[d] >> b) & 1 == 1
253        }
254    }
255}
256
257pub(super) struct R2Buffer<const N: usize>(pub [u64; N], pub [u64; N], pub u64);
258
259impl<const N: usize> R2Buffer<N> {
260    /// Find the number of bits in the binary decomposition of `self`.
261    pub(super) const fn num_bits(&self) -> u32 {
262        ((2 * N) * 64) as u32 + (64 - self.2.leading_zeros())
263    }
264
265    /// Returns the `i`-th bit where bit 0 is the least significant one.
266    /// In other words, the bit with weight `2^i`.
267    pub(super) const fn get_bit(&self, i: usize) -> bool {
268        let d = i / 64;
269        let b = i % 64;
270        if d == 2 * N {
271            (self.2 >> b) & 1 == 1
272        } else if d >= N {
273            (self.1[d - N] >> b) & 1 == 1
274        } else {
275            (self.0[d] >> b) & 1 == 1
276        }
277    }
278}
279
#[cfg(test)]
mod tests {
    use super::*;

    /// 17 distinct non-zero bytes: exactly the capacity of a `SerBuffer<2>`.
    /// Using non-zero data keeps round-trip tests from passing vacuously.
    const BYTES: [u8; 17] = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17];

    /// A `SerBuffer<2>` whose bytes are `BYTES`.
    fn filled_ser_buffer() -> SerBuffer<2> {
        let mut buf = SerBuffer::<2>::zeroed();
        buf.copy_from_u8_slice(&BYTES);
        buf
    }

    /// Writes the first `num_bytes` bytes of `buf` into a fresh vector.
    fn write_prefix(buf: &SerBuffer<2>, num_bytes: usize) -> ark_std::vec::Vec<u8> {
        let mut data = ark_std::vec::Vec::new();
        buf.write_up_to(&mut data, num_bytes)
            .expect("Failed to write buffer");
        data
    }

    #[test]
    fn test_const_for_macro() {
        let mut array = [0usize; 4];
        const_for!((i in 0..(array.len())) {
            array[i] = i;
        });
        assert_eq!(array, [0, 1, 2, 3]);
    }

    #[test]
    fn test_const_for_macro_in_const_context() {
        const fn squares() -> [usize; 4] {
            let mut array = [0usize; 4];
            const_for!((i in 0..4) {
                array[i] = i * i;
            });
            array
        }
        const SQUARES: [usize; 4] = squares();
        assert_eq!(SQUARES, [0, 1, 4, 9]);
    }

    #[test]
    fn test_mul_buffer_new_and_get() {
        type Buf = MulBuffer<4>;
        let buf = Buf::new([1u64, 2u64, 3u64, 4u64], [5u64, 6u64, 7u64, 8u64]);

        assert_eq!(*buf.get(0), 1);
        assert_eq!(*buf.get(3), 4);
        assert_eq!(*buf.get(4), 5);
        assert_eq!(*buf.get(7), 8);
    }

    #[test]
    fn test_mul_buffer_get_mut() {
        type Buf = MulBuffer<4>;
        let mut buf = Buf::zeroed();
        *buf.get_mut(2) = 42;
        assert_eq!(buf.b0[2], 42);

        *buf.get_mut(5) = 99;
        assert_eq!(buf.b1[1], 99);

        // `IndexMut` goes through the same mapping.
        buf[7] = 3;
        assert_eq!(buf.b1[3], 3);
    }

    #[test]
    fn test_mul_buffer_correctness() {
        type Buf = MulBuffer<10>;
        let temp = Buf::new([10u64; 10], [20u64; 10]);

        for i in 0..20 {
            if i < 10 {
                assert_eq!(temp[i], 10);
            } else {
                assert_eq!(temp[i], 20);
            }
        }
    }

    #[test]
    #[should_panic(expected = "index out of bounds")]
    fn test_mul_buffer_soundness() {
        type Buf = MulBuffer<10>;
        let temp = Buf::new([10u64; 10], [10u64; 10]);
        // Index `2 * N` is one past the end and must panic.
        let _value = temp[20];
    }

    #[test]
    fn test_ser_buffer_zeroed_and_get() {
        type Ser = SerBuffer<2>;
        let buf = Ser::zeroed();
        assert_eq!(*buf.get(0), 0);
        assert_eq!(*buf.get(15), 0);
        assert_eq!(*buf.get(16), 0); // Check the `last` byte
    }

    #[test]
    fn test_ser_buffer_index_and_index_mut() {
        let mut buf = filled_ser_buffer();
        assert_eq!(buf[0], 1);
        assert_eq!(buf[9], 10);
        assert_eq!(buf[16], 17); // Index `8 * N` is `last`.

        buf[3] = 99;
        buf[16] = 42;
        assert_eq!(buf.buffers[0][3], 99);
        assert_eq!(buf.last, 42);
    }

    #[test]
    fn test_ser_buffer_copy_from_u8_slice() {
        type Ser = SerBuffer<2>;
        let mut buf = Ser::zeroed();
        buf.copy_from_u8_slice(&BYTES);

        assert_eq!(buf.buffers[0], [1, 2, 3, 4, 5, 6, 7, 8]);
        assert_eq!(buf.buffers[1], [9, 10, 11, 12, 13, 14, 15, 16]);
        assert_eq!(buf.last, 17);

        // A short input only overwrites a prefix.
        let mut buf = Ser::zeroed();
        buf.copy_from_u8_slice(&[1, 2, 3]);
        assert_eq!(buf.buffers[0], [1, 2, 3, 0, 0, 0, 0, 0]);
        assert_eq!(buf.buffers[1], [0; 8]);
        assert_eq!(buf.last, 0);
    }

    #[test]
    fn test_ser_buffer_copy_from_u64_slice() {
        type Ser = SerBuffer<2>;
        let mut buf = Ser::zeroed();
        let data: &[u64; 2] = &[0x123456789ABCDEF0, 0x0FEDCBA987654321];
        buf.copy_from_u64_slice(data);

        assert_eq!(buf.buffers[0], 0x123456789ABCDEF0u64.to_le_bytes());
        assert_eq!(buf.buffers[1], 0x0FEDCBA987654321u64.to_le_bytes());
    }

    #[test]
    fn test_ser_buffer_as_slice_and_to_bigint() {
        let buf = filled_ser_buffer();
        assert_eq!(buf.as_slice(), &BYTES[..]);

        let limbs = buf.to_bigint().0;
        assert_eq!(
            limbs,
            [
                u64::from_le_bytes([1, 2, 3, 4, 5, 6, 7, 8]),
                u64::from_le_bytes([9, 10, 11, 12, 13, 14, 15, 16]),
            ]
        );
    }

    #[test]
    fn test_ser_buffer_last_n_plus_1_bytes_mut() {
        let mut buf = filled_ser_buffer();
        let mut visited = 0;
        for byte in buf.last_n_plus_1_bytes_mut() {
            *byte = 0;
            visited += 1;
        }
        assert_eq!(visited, 9);
        assert_eq!(buf.buffers[0], [1, 2, 3, 4, 5, 6, 7, 8]);
        assert_eq!(buf.buffers[1], [0; 8]);
        assert_eq!(buf.last, 0);
    }

    #[test]
    fn test_ser_buffer_write_and_read() {
        type Ser = SerBuffer<2>;
        let buf = filled_ser_buffer();

        // Full length: both chunks and `last` round-trip.
        let data = write_prefix(&buf, 17);
        assert_eq!(&data[..], &BYTES[..]);
        let mut new_buf = Ser::zeroed();
        new_buf
            .read_exact_up_to(&data[..], 17)
            .expect("Failed to read buffer");
        assert_eq!(buf.buffers, new_buf.buffers);
        assert_eq!(buf.last, new_buf.last);

        // 16 bytes: both chunks, but `last` is neither written nor read.
        let data = write_prefix(&buf, 16);
        assert_eq!(&data[..], &BYTES[..16]);
        let mut new_buf = Ser::zeroed();
        new_buf
            .read_exact_up_to(&data[..], 16)
            .expect("Failed to read buffer");
        assert_eq!(buf.buffers, new_buf.buffers);
        assert_eq!(new_buf.last, 0);

        // 12 bytes: only a prefix of the final chunk.
        let data = write_prefix(&buf, 12);
        assert_eq!(&data[..], &BYTES[..12]);
        let mut new_buf = Ser::zeroed();
        new_buf
            .read_exact_up_to(&data[..], 12)
            .expect("Failed to read buffer");
        assert_eq!(new_buf.buffers[0], buf.buffers[0]);
        assert_eq!(new_buf.buffers[1], [9, 10, 11, 12, 0, 0, 0, 0]);
        assert_eq!(new_buf.last, 0);

        // Too little input is reported as an error.
        assert!(Ser::zeroed().read_exact_up_to(&BYTES[..10], 12).is_err());
    }

    #[test]
    fn test_rbuffer_get_bit() {
        // Limb 0 is zero, limb 1 has only its MSB set, and the high limb is 1.
        let buf = RBuffer([0x0, 0x8000000000000000], 0x1);

        assert!(!buf.get_bit(63)); // MSB of limb 0
        assert!(buf.get_bit(127)); // MSB of limb 1
        assert!(buf.get_bit(128)); // LSB of the high limb
    }

    #[test]
    fn test_rbuffer_num_bits() {
        assert_eq!(RBuffer::<2>([0, 0], 1).num_bits(), 129);
        assert_eq!(RBuffer::<1>([u64::MAX], 0b101).num_bits(), 67);
    }

    #[test]
    fn test_r2buffer_get_bit_and_num_bits() {
        let buf = R2Buffer::<2>([1, 0], [0, 1 << 5], 1);

        assert_eq!(buf.num_bits(), 257);
        assert!(buf.get_bit(0)); // LSB of limb 0
        assert!(!buf.get_bit(1));
        assert!(buf.get_bit(197)); // bit 5 of limb 3 (= `buf.1[1]`)
        assert!(!buf.get_bit(196));
        assert!(buf.get_bit(256)); // LSB of the high limb
        assert!(!buf.get_bit(257));
    }
}