1use ark_serialize::{Read, Write};
2use ark_std::ops::{Index, IndexMut};
3
4use crate::BigInt;
5
/// Counting loop over the half-open range `start..end`.
///
/// Invoked as `const_for!((i in start..end) { body })`. The expansion is a
/// plain `while` loop over a mutable counter, so it does not go through the
/// iterator machinery a `for` loop needs and can be used where only basic
/// control flow is allowed (e.g. inside a `const fn`).
///
/// * `start` and `end` are each matched as a single token tree, so anything
///   more complex than a literal or identifier must be parenthesised, e.g.
///   `0..(array.len())`.
/// * `end` is re-evaluated before every iteration.
/// * The body is placed directly in front of the counter increment. A
///   `continue` inside the body therefore skips the increment and the loop
///   never advances.
#[macro_export]
macro_rules! const_for {
    (($idx:ident in $lo:tt..$hi:tt) $body:expr) => {{
        let mut $idx = $lo;
        while $idx < $hi {
            $body
            $idx += 1;
        }
    }};
}
31
32#[derive(Copy, Clone)]
35#[repr(C, align(8))]
36pub(super) struct MulBuffer<const N: usize> {
37 pub(super) b0: [u64; N],
38 pub(super) b1: [u64; N],
39}
40
41impl<const N: usize> MulBuffer<N> {
42 const fn new(b0: [u64; N], b1: [u64; N]) -> Self {
43 Self { b0, b1 }
44 }
45
46 pub(super) const fn zeroed() -> Self {
47 let b = [0u64; N];
48 Self::new(b, b)
49 }
50
51 #[inline(always)]
52 pub(super) const fn get(&self, index: usize) -> &u64 {
53 if index < N {
54 &self.b0[index]
55 } else {
56 &self.b1[index - N]
57 }
58 }
59
60 #[inline(always)]
61 pub(super) fn get_mut(&mut self, index: usize) -> &mut u64 {
62 if index < N {
63 &mut self.b0[index]
64 } else {
65 &mut self.b1[index - N]
66 }
67 }
68}
69
70impl<const N: usize> Index<usize> for MulBuffer<N> {
71 type Output = u64;
72 #[inline(always)]
73 fn index(&self, index: usize) -> &Self::Output {
74 self.get(index)
75 }
76}
77
78impl<const N: usize> IndexMut<usize> for MulBuffer<N> {
79 #[inline(always)]
80 fn index_mut(&mut self, index: usize) -> &mut Self::Output {
81 self.get_mut(index)
82 }
83}
84
85#[derive(Copy, Clone)]
88#[repr(C, align(1))]
89pub(super) struct SerBuffer<const N: usize> {
90 pub(super) buffers: [[u8; 8]; N],
91 pub(super) last: u8,
92}
93
94impl<const N: usize> SerBuffer<N> {
95 pub(super) const fn zeroed() -> Self {
96 Self {
97 buffers: [[0u8; 8]; N],
98 last: 0u8,
99 }
100 }
101
102 #[inline(always)]
103 pub(super) const fn get(&self, index: usize) -> &u8 {
104 if index == 8 * N {
105 &self.last
106 } else {
107 let part = index / 8;
108 let in_buffer_index = index % 8;
109 &self.buffers[part][in_buffer_index]
110 }
111 }
112
113 #[inline(always)]
114 pub(super) fn get_mut(&mut self, index: usize) -> &mut u8 {
115 if index == 8 * N {
116 &mut self.last
117 } else {
118 let part = index / 8;
119 let in_buffer_index = index % 8;
120 &mut self.buffers[part][in_buffer_index]
121 }
122 }
123
124 #[allow(unsafe_code)]
125 pub(super) const fn as_slice(&self) -> &[u8] {
126 unsafe { ark_std::slice::from_raw_parts((self as *const Self) as *const u8, 8 * N + 1) }
127 }
128
129 #[inline(always)]
130 pub(super) fn last_n_plus_1_bytes_mut(&mut self) -> impl Iterator<Item = &mut u8> {
131 self.buffers[N - 1]
132 .iter_mut()
133 .chain(ark_std::iter::once(&mut self.last))
134 }
135
136 #[inline(always)]
137 pub(super) fn copy_from_u8_slice(&mut self, other: &[u8]) {
138 other.chunks(8).enumerate().for_each(|(i, chunk)| {
139 if i < N {
140 self.buffers[i][..chunk.len()].copy_from_slice(chunk);
141 } else {
142 self.last = chunk[0]
143 }
144 });
145 }
146
147 #[inline(always)]
148 pub(super) fn copy_from_u64_slice(&mut self, other: &[u64; N]) {
149 other
150 .iter()
151 .zip(&mut self.buffers)
152 .for_each(|(other, this)| *this = other.to_le_bytes());
153 }
154
155 #[inline(always)]
156 pub(super) fn to_bigint(self) -> BigInt<N> {
157 let mut self_integer = BigInt::from(0u64);
158 self_integer
159 .0
160 .iter_mut()
161 .zip(self.buffers)
162 .for_each(|(other, this)| *other = u64::from_le_bytes(this));
163 self_integer
164 }
165
166 #[inline(always)]
167 pub(super) fn write_up_to(
170 &self,
171 mut other: impl Write,
172 num_bytes: usize,
173 ) -> ark_std::io::Result<()> {
174 debug_assert!(num_bytes <= 8 * N + 1, "index too large");
175 debug_assert!(num_bytes > 8 * (N - 1), "index too small");
176 for i in 0..(N - 1) {
178 other.write_all(&self.buffers[i])?;
179 }
180 let remaining_bytes = num_bytes - (8 * (N - 1));
183 let write_last_byte = remaining_bytes > 8;
184 let num_last_limb_bytes = ark_std::cmp::min(8, remaining_bytes);
185 other.write_all(&self.buffers[N - 1][..num_last_limb_bytes])?;
186 if write_last_byte {
187 other.write_all(&[self.last])?;
188 }
189 Ok(())
190 }
191
192 #[inline(always)]
193 pub(super) fn read_exact_up_to(
196 &mut self,
197 mut other: impl Read,
198 num_bytes: usize,
199 ) -> ark_std::io::Result<()> {
200 debug_assert!(num_bytes <= 8 * N + 1, "index too large");
201 debug_assert!(num_bytes > 8 * (N - 1), "index too small");
202 for i in 0..(N - 1) {
204 other.read_exact(&mut self.buffers[i])?;
205 }
206 let remaining_bytes = num_bytes - (8 * (N - 1));
209 let write_last_byte = remaining_bytes > 8;
210 let num_last_limb_bytes = ark_std::cmp::min(8, remaining_bytes);
211 other.read_exact(&mut self.buffers[N - 1][..num_last_limb_bytes])?;
212 if write_last_byte {
213 let mut last = [0u8; 1];
214 other.read_exact(&mut last)?;
215 self.last = last[0];
216 }
217 Ok(())
218 }
219}
220
221impl<const N: usize> Index<usize> for SerBuffer<N> {
222 type Output = u8;
223 #[inline(always)]
224 fn index(&self, index: usize) -> &Self::Output {
225 self.get(index)
226 }
227}
228
229impl<const N: usize> IndexMut<usize> for SerBuffer<N> {
230 #[inline(always)]
231 fn index_mut(&mut self, index: usize) -> &mut Self::Output {
232 self.get_mut(index)
233 }
234}
235
236pub(super) struct RBuffer<const N: usize>(pub [u64; N], pub u64);
237
238impl<const N: usize> RBuffer<N> {
239 pub(super) const fn num_bits(&self) -> u32 {
241 (N * 64) as u32 + (64 - self.1.leading_zeros())
242 }
243
244 pub(super) const fn get_bit(&self, i: usize) -> bool {
247 let d = i / 64;
248 let b = i % 64;
249 if d == N {
250 (self.1 >> b) & 1 == 1
251 } else {
252 (self.0[d] >> b) & 1 == 1
253 }
254 }
255}
256
257pub(super) struct R2Buffer<const N: usize>(pub [u64; N], pub [u64; N], pub u64);
258
259impl<const N: usize> R2Buffer<N> {
260 pub(super) const fn num_bits(&self) -> u32 {
262 ((2 * N) * 64) as u32 + (64 - self.2.leading_zeros())
263 }
264
265 pub(super) const fn get_bit(&self, i: usize) -> bool {
268 let d = i / 64;
269 let b = i % 64;
270 if d == 2 * N {
271 (self.2 >> b) & 1 == 1
272 } else if d >= N {
273 (self.1[d - N] >> b) & 1 == 1
274 } else {
275 (self.0[d] >> b) & 1 == 1
276 }
277 }
278}
279
/// Unit tests for the const-context helpers and buffer types in this module.
#[cfg(test)]
mod tests {
    use super::*;

    /// `const_for!` visits every index of the range once, in order.
    #[test]
    fn test_const_for_macro() {
        let mut array = [0usize; 4];
        const_for!((i in 0..(array.len())) {
            array[i] = i;
        });
        assert_eq!(array, [0, 1, 2, 3]);
    }

    /// Flat indices `0..N` read `b0`, `N..2N` read `b1`.
    #[test]
    fn test_mul_buffer_new_and_get() {
        type Buf = MulBuffer<4>;
        let buf = Buf::new([1u64, 2u64, 3u64, 4u64], [5u64, 6u64, 7u64, 8u64]);

        assert_eq!(*buf.get(0), 1);
        assert_eq!(*buf.get(3), 4);
        assert_eq!(*buf.get(4), 5);
        assert_eq!(*buf.get(7), 8);
    }

    /// Writes through `get_mut` land in the matching half.
    #[test]
    fn test_mul_buffer_get_mut() {
        type Buf = MulBuffer<4>;
        let mut buf = Buf::zeroed();
        *buf.get_mut(2) = 42;
        assert_eq!(buf.b0[2], 42);

        *buf.get_mut(5) = 99;
        assert_eq!(buf.b1[1], 99);
    }

    #[test]
    fn test_ser_buffer_zeroed_and_get() {
        type Ser = SerBuffer<2>;
        let buf = Ser::zeroed();
        assert_eq!(*buf.get(0), 0);
        assert_eq!(*buf.get(15), 0);
        // Index `8 * N` addresses the trailing byte.
        assert_eq!(*buf.get(16), 0);
    }

    #[test]
    fn test_ser_buffer_copy_from_u8_slice() {
        type Ser = SerBuffer<2>;
        let mut buf = Ser::zeroed();
        let data: &[u8] = &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17];
        buf.copy_from_u8_slice(data);

        assert_eq!(buf.buffers[0], [1, 2, 3, 4, 5, 6, 7, 8]);
        assert_eq!(buf.buffers[1], [9, 10, 11, 12, 13, 14, 15, 16]);
        assert_eq!(buf.last, 17);
    }

    #[test]
    fn test_ser_buffer_copy_from_u64_slice() {
        type Ser = SerBuffer<2>;
        let mut buf = Ser::zeroed();
        let data: &[u64; 2] = &[0x123456789ABCDEF0, 0x0FEDCBA987654321];
        buf.copy_from_u64_slice(data);

        assert_eq!(buf.buffers[0], 0x123456789ABCDEF0u64.to_le_bytes());
        assert_eq!(buf.buffers[1], 0x0FEDCBA987654321u64.to_le_bytes());
    }

    #[test]
    fn test_rbuffer_get_bit() {
        let buf = RBuffer([0x0, 0x8000000000000000], 0x1);
        // Top bit of limb 0 is clear.
        assert!(!buf.get_bit(63));
        // Top bit of limb 1 is set.
        assert!(buf.get_bit(127));
        // Bit 0 of the separate top limb is set.
        assert!(buf.get_bit(128));
    }

    /// Round-trips distinct, non-zero bytes so that wrong content or byte
    /// order is detected, for both the shortest length that skips the
    /// trailing byte and the full length that includes it.
    #[test]
    fn test_ser_buffer_write_and_read() {
        type Ser = SerBuffer<2>;
        let bytes: [u8; 17] = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17];
        let mut buf = Ser::zeroed();
        buf.copy_from_u8_slice(&bytes);

        // 16 bytes: both groups, without the trailing byte.
        let mut data: ark_std::vec::Vec<u8> = ark_std::vec::Vec::new();
        buf.write_up_to(&mut data, 16)
            .expect("Failed to write buffer");
        assert_eq!(&data[..], &bytes[..16]);

        let mut new_buf = Ser::zeroed();
        new_buf
            .read_exact_up_to(&data[..], 16)
            .expect("Failed to read buffer");
        assert_eq!(buf.buffers, new_buf.buffers);
        // The trailing byte was neither written nor read.
        assert_eq!(new_buf.last, 0);

        // 17 bytes: the whole buffer including the trailing byte.
        let mut data: ark_std::vec::Vec<u8> = ark_std::vec::Vec::new();
        buf.write_up_to(&mut data, 17)
            .expect("Failed to write buffer");
        assert_eq!(&data[..], &bytes[..]);

        let mut new_buf = Ser::zeroed();
        new_buf
            .read_exact_up_to(&data[..], 17)
            .expect("Failed to read buffer");
        assert_eq!(buf.buffers, new_buf.buffers);
        assert_eq!(buf.last, new_buf.last);
    }

    /// Indexing through `Index` agrees with the two halves.
    #[test]
    fn test_mul_buffer_correctness() {
        type Buf = MulBuffer<10>;
        let temp = Buf::new([10u64; 10], [20u64; 10]);

        for i in 0..20 {
            if i < 10 {
                assert_eq!(temp[i], 10);
            } else {
                assert_eq!(temp[i], 20);
            }
        }
    }

    /// Indexing one past the end (`2 * N`) must panic.
    #[test]
    #[should_panic]
    fn test_mul_buffer_soundness() {
        type Buf = MulBuffer<10>;
        let temp = Buf::new([10u64; 10], [10u64; 10]);

        let _ = temp[20];
    }
}