cyclone_msm/
app.rs

1//! Host-side app to interact with FPGA app.
2use core::iter;
3
4use ark_bls12_377::{Fq, Fr, G1Affine, G1TEProjective};
5use ark_std::Zero;
6
7use fpga::{null::Backoff as NullBackoff, Flush as _, ReadWrite as _, Streamable as _, Write as _};
8
9#[cfg(not(feature = "hw"))]
10pub use fpga::Null as Fpga;
11#[cfg(feature = "hw")]
12pub use fpga::F1 as Fpga;
13
14use crate::{
15    bls12_377::{into_weierstrass, G1PTEAffine},
16    precompute::{limb_carries, single_digit_carry},
17    timing::timed,
18    App, Command, G1Projective, Packet, Scalar,
19};
20
/// Value written to the `DdrReadLen` register when the app is initialised.
const DDR_READ_LEN: u32 = 64;

/// Number of buckets; used in this file only to derive [`LAST_BUCKET`].
const NUM_BUCKETS: u32 = 1 << 15;
/// Value written to the `FirstBucket` register when the app is initialised.
const FIRST_BUCKET: u32 = 0;
/// Value written to the `LastBucket` register when the app is initialised.
const LAST_BUCKET: u32 = NUM_BUCKETS - 1;

/// `DigitsBackoff` busy-waits while the `DigitsQueue` register reads above this value.
const BACKOFF_THRESHOLD: u32 = 64;
/// `SetPointsBackoff` flushes whenever the stream offset is a multiple of this value.
const SET_POINTS_FLUSH_EVERY: usize = 1024;
/// `DigitsBackoff` flushes and polls whenever the stream offset is a multiple of this value.
const SET_DIGITS_FLUSH_BACKOFF_EVERY: usize = 512;
30
31type FpgaStream<'a, B> = fpga::Stream<'a, Packet, Fpga, B>;
32
33fn shl_assign(point: &mut G1TEProjective, c: usize) {
34    use ark_ec::Group as _;
35    (0..c).for_each(|_| {
36        point.double_in_place();
37    })
38}
39
/// Top-level commands of the FPGA App's interface.
///
/// Each variant is a small index shifted left by 26 bits; the value is passed to
/// `fpga.stream(..)` to select the destination of the packets that follow.
/// Of more interest are the subcommands of `Stream::Msm` in [`Command`].
#[repr(usize)]
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum Stream {
    /// Upload of point x-coordinates.
    SetX = 1 << 26,
    /// Upload of point y-coordinates.
    SetY = 2 << 26,
    /// Upload of the points' `kt` values.
    SetKT = 3 << 26,
    /// MSM digits: must start with `Command::StartColumn`, then packets of `Command::SetDigit`.
    Msm = 4 << 26,
    /// Upload of the four coordinates of the zero point.
    SetZero = 5 << 26,
}
52
53impl Command {
54    #[inline(always)]
55    pub fn set_digit(digit: i16) -> u64 {
56        Command::SetDigit as u64 | (digit as u16 as u64) << 14
57    }
58}
59
/// Addresses of the registers this app writes to.
#[repr(u32)]
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum WriteRegister {
    /// Parameter for the next parametrised read: the read registers `Statistic`, `X`, `Y`,
    /// `Z` and `T` need a preceding query with the parameter.
    Query = 0x10,
    /// Receives `DDR_READ_LEN` during initialisation.
    DdrReadLen = 0x11,
    /// Receives the number of points of the MSM.
    MsmLength = 0x20,
    /// Receives `LAST_BUCKET` during initialisation.
    LastBucket = 0x21,
    /// Receives `FIRST_BUCKET` during initialisation.
    FirstBucket = 0x22,
}
70
/// Addresses of the registers this app reads from.
#[repr(u32)]
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum ReadRegister {
    /// Value of the statistic previously selected through `WriteRegister::Query`.
    Statistic = 0x20,
    /// Polled by `DigitsBackoff`, which waits until the value is at most `BACKOFF_THRESHOLD`.
    DigitsQueue = 0x21,
    /// Polled by `get_point`, which waits until the value is non-zero.
    Aggregated = 0x30,
    /// 32-bit word of the result's x-coordinate selected through `WriteRegister::Query`.
    X = 0x31,
    /// 32-bit word of the result's y-coordinate selected through `WriteRegister::Query`.
    Y = 0x32,
    /// 32-bit word of the result's z-coordinate selected through `WriteRegister::Query`.
    Z = 0x33,
    /// 32-bit word of the result's t-coordinate selected through `WriteRegister::Query`.
    T = 0x34,
}
82
/// Selector written to `WriteRegister::Query` before reading `ReadRegister::Statistic`.
///
/// Derives the same traits as the other register enums in this file, so selectors can be
/// copied, compared and printed.
// TODO: double check these are named correctly
#[repr(u32)]
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum Statistic {
    DroppedCommands = 0,
    DdrReadMiss = 1,
    DdrWriteMiss = 2,
    DdrPushCount = 3,
    DdrReadCountChannel1 = 4,
    DdrReadCountChannel2 = 5,
    DdrReadCountChannel3 = 6,
}
94
/// Snapshot of all [`Statistic`] values, one field per selector.
// TODO: double check these are named correctly
#[derive(Copy, Clone, Debug)]
pub struct Statistics {
    /// Value read for `Statistic::DroppedCommands`.
    pub dropped_commands: u32,
    /// Value read for `Statistic::DdrReadMiss`.
    pub ddr_read_miss: u32,
    /// Value read for `Statistic::DdrWriteMiss`.
    pub ddr_write_miss: u32,
    /// Value read for `Statistic::DdrPushCount`.
    pub ddr_push_count: u32,
    /// Value read for `Statistic::DdrReadCountChannel1`.
    pub ddr_read_count_channel_1: u32,
    /// Value read for `Statistic::DdrReadCountChannel2`.
    pub ddr_read_count_channel_2: u32,
    /// Value read for `Statistic::DdrReadCountChannel3`.
    pub ddr_read_count_channel_3: u32,
}
106
107impl App {
108    pub fn new(fpga: Fpga, size: u8) -> Self {
109        assert!(size <= 27);
110        let pool = rayon::ThreadPoolBuilder::new()
111            .num_threads(2)
112            .build()
113            .unwrap();
114        let mut app = App {
115            fpga,
116            len: 1 << size,
117            pool: Some(pool),
118            carried: Some(vec![Scalar::default(); 1 << size]),
119        };
120        app.set_size();
121        app.set_first_bucket();
122        app.set_last_bucket();
123        app.set_ddr_read_len();
124        app.set_zero();
125
126        app
127    }
128
129    #[inline]
130    fn column<'a>(
131        &mut self,
132        i: usize,
133        scalars: impl Iterator<Item = &'a Scalar> + Clone + Send,
134        total: &mut G1TEProjective,
135    ) {
136        let mut cmds = Packet::default();
137        for j in (0..4).rev() {
138            timed(&format!("\n:: column {}", j as usize), || {
139                let mut stream = self.start_column();
140
141                let mut k = 0;
142                for scalar in scalars.clone() {
143                    let digit = single_digit_carry(scalar, i, j);
144                    cmds[k] = Command::set_digit(digit);
145                    k += 1;
146                    if k == 8 {
147                        stream.write(&cmds);
148                        k = 0;
149                    }
150                }
151                *total += timed("fetching point", || self.get_point());
152                if (i, j) != (0, 0) {
153                    shl_assign(total, 16);
154                }
155            });
156        }
157    }
158
159    /// Perform full MSM.
160    #[inline]
161    pub fn msm<'a>(
162        &mut self,
163        scalars: impl Iterator<Item = &'a Scalar> + Clone + ExactSizeIterator + Send,
164    ) -> G1Projective {
165        assert_eq!(scalars.len(), self.len as _);
166
167        let pool = self.pool.take().unwrap_or_else(|| unreachable!());
168        let mut carried = self.carried.take().unwrap_or_else(|| unreachable!());
169
170        let mut total = G1TEProjective::zero();
171        let mut total0 = G1TEProjective::zero();
172        let scalars_for_carry_calculation = scalars.clone();
173        let scalars_for_column_0_calculation = scalars;
174        pool.scope(|s| {
175            s.spawn(|_| {
176                timed("limb carries", || {
177                    limb_carries(scalars_for_carry_calculation, &mut carried)
178                });
179            });
180
181            s.spawn(|_| {
182                self.column(0, scalars_for_column_0_calculation, &mut total0);
183            });
184        });
185
186        for i in (1..4).rev() {
187            self.column(i, carried.iter(), &mut total);
188        }
189
190        shl_assign(&mut total, 48);
191        total += total0;
192
193        let total = into_weierstrass(&total);
194        self.pool = Some(pool);
195        self.carried = Some(carried);
196        total
197    }
198
199    /// Like `ark_ec::scalar_mul::variable_base::VariableBaseMSM::msm_bigint`
200    pub fn msm_bigint(&mut self, scalars: &[<Fr as ark_ff::PrimeField>::BigInt]) -> G1Projective {
201        self.msm(scalars.iter().map(|scalar| &scalar.0))
202    }
203
204    pub const fn len(&self) -> usize {
205        self.len
206    }
207
208    pub const fn is_empty(&self) -> bool {
209        self.len == 0
210    }
211
212    fn set_zero(&mut self) {
213        let zero = G1TEProjective::zero();
214        let mut packet = Packet::default();
215
216        let mut stream: FpgaStream<'_, NullBackoff> = self.fpga.stream(Stream::SetZero as _);
217
218        packet[..6].copy_from_slice(zero.x.0.as_ref());
219        stream.write(&packet);
220        packet[..6].copy_from_slice(zero.y.0.as_ref());
221        stream.write(&packet);
222        packet[..6].copy_from_slice(zero.z.0.as_ref());
223        stream.write(&packet);
224        packet[..6].copy_from_slice(zero.t.0.as_ref());
225        stream.write(&packet);
226
227        self.fpga.flush();
228    }
229
230    fn set_size(&mut self) {
231        self.fpga
232            .write(WriteRegister::MsmLength as _, &(self.len as u32));
233    }
234
235    fn set_last_bucket(&mut self) {
236        self.fpga
237            .write(WriteRegister::LastBucket as _, &LAST_BUCKET);
238    }
239
240    fn set_first_bucket(&mut self) {
241        self.fpga
242            .write(WriteRegister::FirstBucket as _, &FIRST_BUCKET);
243    }
244
245    fn set_ddr_read_len(&mut self) {
246        self.fpga
247            .write(WriteRegister::DdrReadLen as _, &DDR_READ_LEN);
248    }
249
250    #[inline]
251    fn set_coordinates(&mut self, coordinate: Stream, coordinates: impl Iterator<Item = Fq>) {
252        debug_assert!([
253            coordinate == Stream::SetX,
254            coordinate == Stream::SetY,
255            coordinate == Stream::SetKT
256        ]
257        .iter()
258        .any(|&condition| condition));
259        let mut packet = Packet::default();
260        let mut stream: FpgaStream<'_, SetPointsBackoff> = self.fpga.stream(coordinate as _);
261        for coordinate in coordinates {
262            packet[..6].copy_from_slice(coordinate.0.as_ref());
263            stream.write(&packet);
264        }
265    }
266    #[inline]
267    pub fn set_preprocessed_points(&mut self, points: &[G1PTEAffine]) {
268        assert!(self.len == points.len());
269
270        self.set_coordinates(Stream::SetX, points.iter().map(|point| point.x));
271        self.set_coordinates(Stream::SetY, points.iter().map(|point| point.y));
272        self.set_coordinates(Stream::SetKT, points.iter().map(|point| point.kt));
273    }
274
275    pub fn set_points(&mut self, points: &[G1Affine]) {
276        assert!(self.len == points.len());
277        let preprocessed_points: Vec<_> = points.iter().map(|point| point.into()).collect();
278        self.set_preprocessed_points(&preprocessed_points);
279    }
280
281    pub fn set_preprocessed_point_repeatedly(&mut self, point: &G1PTEAffine) {
282        self.set_coordinates(Stream::SetX, iter::repeat(point.x).take(self.len));
283        self.set_coordinates(Stream::SetY, iter::repeat(point.y).take(self.len));
284        self.set_coordinates(Stream::SetKT, iter::repeat(point.kt).take(self.len));
285    }
286
287    #[cfg(feature = "hw")]
288    fn get_coordinate(&mut self, coordinate: ReadRegister) -> Fq {
289        debug_assert!([
290            coordinate == ReadRegister::X,
291            coordinate == ReadRegister::Y,
292            coordinate == ReadRegister::Z,
293            coordinate == ReadRegister::T,
294        ]
295        .iter()
296        .any(|&condition| condition));
297        let mut buffer = [0u64; 6];
298        for j in (0..12).step_by(2) {
299            self.fpga.write(WriteRegister::Query as _, &j);
300            let lo = self.fpga.read(coordinate as _);
301
302            self.fpga.write(WriteRegister::Query as _, &(j + 1));
303            let hi = self.fpga.read(coordinate as _);
304
305            // | has lower precedence than <<, whereas + has higher
306            // and would need parentheses
307            buffer[j as usize / 2] = (hi as u64) << 32 | lo as u64;
308        }
309        ark_ff::BigInt(buffer).into()
310    }
311
312    #[cfg(feature = "hw")]
313    pub fn get_point(&mut self) -> G1TEProjective {
314        self.fpga.flush();
315        while 0 == self.fpga.read(ReadRegister::Aggregated as _) {
316            continue;
317        }
318
319        let mut point = G1TEProjective::zero();
320        point.x = self.get_coordinate(ReadRegister::X);
321        point.y = self.get_coordinate(ReadRegister::Y);
322        point.z = self.get_coordinate(ReadRegister::Z);
323        point.t = self.get_coordinate(ReadRegister::T);
324
325        point
326    }
327
328    #[cfg(not(feature = "hw"))]
329    pub fn get_point(&mut self) -> G1TEProjective {
330        G1TEProjective::zero()
331    }
332
333    pub fn statistics(&mut self) -> Statistics {
334        use Statistic::*;
335        Statistics {
336            dropped_commands: self.statistic(DroppedCommands),
337            ddr_read_miss: self.statistic(DdrReadMiss),
338            ddr_write_miss: self.statistic(DdrWriteMiss),
339            ddr_push_count: self.statistic(DdrPushCount),
340            ddr_read_count_channel_1: self.statistic(DdrReadCountChannel1),
341            ddr_read_count_channel_2: self.statistic(DdrReadCountChannel2),
342            ddr_read_count_channel_3: self.statistic(DdrReadCountChannel3),
343        }
344    }
345
346    pub fn statistic(&mut self, statistic: Statistic) -> u32 {
347        self.fpga
348            .write(WriteRegister::Query as _, &(statistic as u32));
349        self.fpga.read(ReadRegister::Statistic as _)
350    }
351
352    pub fn start_column(&mut self) -> FpgaStream<'_, DigitsBackoff> {
353        let mut stream = self.fpga.stream(Stream::Msm as _);
354
355        let mut packet = Packet::default();
356        packet[0] = Command::StartColumn as _;
357        stream.write(&packet);
358        stream.flush();
359        stream
360    }
361}
362
363pub struct SetPointsBackoff;
364impl fpga::Backoff<Fpga> for SetPointsBackoff {
365    #[inline(always)]
366    fn backoff(fpga: &mut Fpga, offset: usize) {
367        if (offset % SET_POINTS_FLUSH_EVERY) == 0 {
368            fpga.flush();
369        }
370    }
371}
372
373pub struct DigitsBackoff;
374impl fpga::Backoff<Fpga> for DigitsBackoff {
375    #[inline(always)]
376    fn backoff(fpga: &mut Fpga, offset: usize) {
377        if (offset % SET_DIGITS_FLUSH_BACKOFF_EVERY) == 0 {
378            fpga.flush();
379            while fpga.read(ReadRegister::DigitsQueue as _) > BACKOFF_THRESHOLD {
380                continue;
381            }
382        }
383    }
384}