risc0_zkp/hal/
dual.rs

1// Copyright 2024 RISC Zero, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::{fmt::Debug, marker::PhantomData, rc::Rc};
16
17use risc0_core::field::Field;
18
19use super::{AccumPreflight, Buffer, CircuitHal, Hal};
20use crate::core::{digest::Digest, hash::HashSuite};
21
22#[derive(Clone)]
23pub struct BufferImpl<T, L, R>
24where
25    L: Buffer<T>,
26    R: Buffer<T>,
27{
28    lhs: L,
29    rhs: R,
30    phantom: PhantomData<T>,
31}
32
33impl<T, L, R> BufferImpl<T, L, R>
34where
35    T: Debug + PartialEq,
36    L: Buffer<T>,
37    R: Buffer<T>,
38{
39    pub fn new(lhs: L, rhs: R) -> Self {
40        Self {
41            lhs,
42            rhs,
43            phantom: PhantomData,
44        }
45    }
46
47    fn assert_eq(&self) {
48        self.lhs.view(|lhs| {
49            self.rhs.view(|rhs| {
50                assert_eq!(lhs.len(), rhs.len());
51                assert_eq!(lhs, rhs);
52            });
53        })
54    }
55}
56
57impl<T, L, R> Buffer<T> for BufferImpl<T, L, R>
58where
59    T: Clone + Debug + PartialEq,
60    L: Buffer<T>,
61    R: Buffer<T>,
62{
63    fn name(&self) -> &'static str {
64        "dual"
65    }
66
67    fn size(&self) -> usize {
68        let lhs = self.lhs.size();
69        let rhs = self.rhs.size();
70        assert_eq!(lhs, rhs);
71        lhs
72    }
73
74    fn slice(&self, offset: usize, size: usize) -> Self {
75        let lhs = self.lhs.slice(offset, size);
76        let rhs = self.rhs.slice(offset, size);
77        BufferImpl::new(lhs, rhs)
78    }
79
80    fn get_at(&self, idx: usize) -> T {
81        self.lhs.get_at(idx)
82    }
83
84    fn view<F: FnOnce(&[T])>(&self, f: F) {
85        self.lhs.view(f)
86    }
87
88    fn view_mut<F: FnOnce(&mut [T])>(&self, f: F) {
89        self.lhs.view_mut(f);
90        self.rhs.view_mut(|dst| {
91            self.lhs.view(|src| dst.clone_from_slice(src));
92        })
93    }
94
95    fn to_vec(&self) -> Vec<T> {
96        self.lhs.to_vec()
97    }
98}
99
100pub struct DualHal<F, L, R>
101where
102    L: Hal<Field = F>,
103    R: Hal<Field = F>,
104{
105    lhs: Rc<L>,
106    rhs: Rc<R>,
107}
108
109impl<F, L, R> DualHal<F, L, R>
110where
111    L: Hal<Field = F>,
112    R: Hal<Field = F>,
113{
114    pub fn new(lhs: Rc<L>, rhs: Rc<R>) -> Self {
115        Self { lhs, rhs }
116    }
117}
118
119impl<F, L, R> Hal for DualHal<F, L, R>
120where
121    F: Field,
122    L: Hal<Field = F, Elem = F::Elem, ExtElem = F::ExtElem>,
123    R: Hal<Field = F, Elem = F::Elem, ExtElem = F::ExtElem>,
124{
125    type Field = F;
126    type Elem = F::Elem;
127    type ExtElem = F::ExtElem;
128    type Buffer<T: Clone + Debug + PartialEq> = BufferImpl<T, L::Buffer<T>, R::Buffer<T>>;
129
130    fn get_hash_suite(&self) -> &HashSuite<Self::Field> {
131        self.lhs.get_hash_suite()
132    }
133
134    fn alloc_digest(&self, name: &'static str, size: usize) -> Self::Buffer<Digest> {
135        let lhs = self.lhs.alloc_digest(name, size);
136        let rhs = self.rhs.alloc_digest(name, size);
137        BufferImpl::new(lhs, rhs)
138    }
139
140    fn alloc_elem(&self, name: &'static str, size: usize) -> Self::Buffer<Self::Elem> {
141        let lhs = self.lhs.alloc_elem(name, size);
142        let rhs = self.rhs.alloc_elem(name, size);
143        BufferImpl::new(lhs, rhs)
144    }
145
146    fn alloc_extelem(&self, name: &'static str, size: usize) -> Self::Buffer<Self::ExtElem> {
147        let lhs = self.lhs.alloc_extelem(name, size);
148        let rhs = self.rhs.alloc_extelem(name, size);
149        BufferImpl::new(lhs, rhs)
150    }
151
152    fn alloc_u32(&self, name: &'static str, size: usize) -> Self::Buffer<u32> {
153        let lhs = self.lhs.alloc_u32(name, size);
154        let rhs = self.rhs.alloc_u32(name, size);
155        BufferImpl::new(lhs, rhs)
156    }
157
158    fn alloc_elem_init(
159        &self,
160        name: &'static str,
161        size: usize,
162        value: Self::Elem,
163    ) -> Self::Buffer<Self::Elem> {
164        let lhs = self.lhs.alloc_elem_init(name, size, value);
165        let rhs = self.rhs.alloc_elem_init(name, size, value);
166        BufferImpl::new(lhs, rhs)
167    }
168
169    fn copy_from_digest(&self, name: &'static str, slice: &[Digest]) -> Self::Buffer<Digest> {
170        let lhs = self.lhs.copy_from_digest(name, slice);
171        let rhs = self.rhs.copy_from_digest(name, slice);
172        BufferImpl::new(lhs, rhs)
173    }
174
175    fn copy_from_elem(&self, name: &'static str, slice: &[Self::Elem]) -> Self::Buffer<Self::Elem> {
176        let lhs = self.lhs.copy_from_elem(name, slice);
177        let rhs = self.rhs.copy_from_elem(name, slice);
178        BufferImpl::new(lhs, rhs)
179    }
180
181    fn copy_from_extelem(
182        &self,
183        name: &'static str,
184        slice: &[Self::ExtElem],
185    ) -> Self::Buffer<Self::ExtElem> {
186        let lhs = self.lhs.copy_from_extelem(name, slice);
187        let rhs = self.rhs.copy_from_extelem(name, slice);
188        BufferImpl::new(lhs, rhs)
189    }
190
191    fn copy_from_u32(&self, name: &'static str, slice: &[u32]) -> Self::Buffer<u32> {
192        let lhs = self.lhs.copy_from_u32(name, slice);
193        let rhs = self.rhs.copy_from_u32(name, slice);
194        BufferImpl::new(lhs, rhs)
195    }
196
197    fn batch_expand_into_evaluate_ntt(
198        &self,
199        output: &Self::Buffer<Self::Elem>,
200        input: &Self::Buffer<Self::Elem>,
201        count: usize,
202        expand_bits: usize,
203    ) {
204        self.lhs
205            .batch_expand_into_evaluate_ntt(&output.lhs, &input.lhs, count, expand_bits);
206        self.rhs
207            .batch_expand_into_evaluate_ntt(&output.rhs, &input.rhs, count, expand_bits);
208        output.assert_eq();
209    }
210
211    fn batch_interpolate_ntt(&self, io: &Self::Buffer<Self::Elem>, count: usize) {
212        self.lhs.batch_interpolate_ntt(&io.lhs, count);
213        self.rhs.batch_interpolate_ntt(&io.rhs, count);
214        io.assert_eq();
215    }
216
217    fn batch_bit_reverse(&self, io: &Self::Buffer<Self::Elem>, count: usize) {
218        self.lhs.batch_bit_reverse(&io.lhs, count);
219        self.rhs.batch_bit_reverse(&io.rhs, count);
220        io.assert_eq();
221    }
222
223    fn batch_evaluate_any(
224        &self,
225        coeffs: &Self::Buffer<Self::Elem>,
226        poly_count: usize,
227        which: &Self::Buffer<u32>,
228        xs: &Self::Buffer<Self::ExtElem>,
229        out: &Self::Buffer<Self::ExtElem>,
230    ) {
231        self.lhs
232            .batch_evaluate_any(&coeffs.lhs, poly_count, &which.lhs, &xs.lhs, &out.lhs);
233        self.rhs
234            .batch_evaluate_any(&coeffs.rhs, poly_count, &which.rhs, &xs.rhs, &out.rhs);
235        out.assert_eq();
236    }
237
238    fn zk_shift(&self, io: &Self::Buffer<Self::Elem>, count: usize) {
239        self.lhs.zk_shift(&io.lhs, count);
240        self.rhs.zk_shift(&io.rhs, count);
241        io.assert_eq();
242    }
243
244    fn mix_poly_coeffs(
245        &self,
246        out: &Self::Buffer<Self::ExtElem>,
247        mix_start: &Self::ExtElem,
248        mix: &Self::ExtElem,
249        input: &Self::Buffer<Self::Elem>,
250        combos: &Self::Buffer<u32>,
251        input_size: usize,
252        count: usize,
253    ) {
254        self.lhs.mix_poly_coeffs(
255            &out.lhs,
256            mix_start,
257            mix,
258            &input.lhs,
259            &combos.lhs,
260            input_size,
261            count,
262        );
263        self.rhs.mix_poly_coeffs(
264            &out.rhs,
265            mix_start,
266            mix,
267            &input.rhs,
268            &combos.rhs,
269            input_size,
270            count,
271        );
272        out.assert_eq();
273    }
274
275    fn eltwise_add_elem(
276        &self,
277        output: &Self::Buffer<Self::Elem>,
278        input1: &Self::Buffer<Self::Elem>,
279        input2: &Self::Buffer<Self::Elem>,
280    ) {
281        self.lhs
282            .eltwise_add_elem(&output.lhs, &input1.lhs, &input2.lhs);
283        self.rhs
284            .eltwise_add_elem(&output.rhs, &input1.rhs, &input2.rhs);
285        output.assert_eq();
286    }
287
288    fn eltwise_sum_extelem(
289        &self,
290        output: &Self::Buffer<Self::Elem>,
291        input: &Self::Buffer<Self::ExtElem>,
292    ) {
293        self.lhs.eltwise_sum_extelem(&output.lhs, &input.lhs);
294        self.rhs.eltwise_sum_extelem(&output.rhs, &input.rhs);
295        output.assert_eq();
296    }
297
298    fn eltwise_copy_elem(
299        &self,
300        output: &Self::Buffer<Self::Elem>,
301        input: &Self::Buffer<Self::Elem>,
302    ) {
303        self.lhs.eltwise_copy_elem(&output.lhs, &input.lhs);
304        self.rhs.eltwise_copy_elem(&output.rhs, &input.rhs);
305        output.assert_eq();
306    }
307
308    fn eltwise_zeroize_elem(&self, elems: &Self::Buffer<Self::Elem>) {
309        self.lhs.eltwise_zeroize_elem(&elems.lhs);
310        self.rhs.eltwise_zeroize_elem(&elems.rhs);
311        elems.assert_eq();
312    }
313
314    fn fri_fold(
315        &self,
316        output: &Self::Buffer<Self::Elem>,
317        input: &Self::Buffer<Self::Elem>,
318        mix: &Self::ExtElem,
319    ) {
320        self.lhs.fri_fold(&output.lhs, &input.lhs, mix);
321        self.rhs.fri_fold(&output.rhs, &input.rhs, mix);
322        output.assert_eq();
323    }
324
325    fn hash_rows(&self, output: &Self::Buffer<Digest>, matrix: &Self::Buffer<Self::Elem>) {
326        self.lhs.hash_rows(&output.lhs, &matrix.lhs);
327        self.rhs.hash_rows(&output.rhs, &matrix.rhs);
328        output.assert_eq();
329    }
330
331    fn hash_fold(&self, io: &Self::Buffer<Digest>, input_size: usize, output_size: usize) {
332        self.lhs.hash_fold(&io.lhs, input_size, output_size);
333        self.rhs.hash_fold(&io.rhs, input_size, output_size);
334        io.assert_eq();
335    }
336
337    fn has_unified_memory(&self) -> bool {
338        self.rhs.has_unified_memory()
339    }
340
341    fn gather_sample(
342        &self,
343        dst: &Self::Buffer<Self::Elem>,
344        src: &Self::Buffer<Self::Elem>,
345        idx: usize,
346        size: usize,
347        stride: usize,
348    ) {
349        self.lhs
350            .gather_sample(&dst.lhs, &src.lhs, idx, size, stride);
351        self.rhs
352            .gather_sample(&dst.rhs, &src.rhs, idx, size, stride);
353        dst.assert_eq();
354    }
355
356    fn prefix_products(&self, io: &Self::Buffer<Self::ExtElem>) {
357        self.lhs.prefix_products(&io.lhs);
358        self.rhs.prefix_products(&io.rhs);
359        // io.assert_eq();
360
361        io.lhs.view(|lhs| {
362            io.rhs.view(|rhs| {
363                assert_eq!(lhs.len(), rhs.len());
364                for i in 0..lhs.len() {
365                    assert_eq!(lhs[i], rhs[i], "{i}");
366                }
367            });
368        })
369    }
370
371    fn scatter(
372        &self,
373        into: &Self::Buffer<Self::Elem>,
374        index: &[u32],
375        offsets: &[u32],
376        values: &[Self::Elem],
377    ) {
378        self.lhs.scatter(&into.lhs, index, offsets, values);
379        self.rhs.scatter(&into.rhs, index, offsets, values);
380        into.assert_eq();
381    }
382
383    fn eltwise_copy_elem_slice(
384        &self,
385        into: &Self::Buffer<Self::Elem>,
386        from: &[Self::Elem],
387        from_rows: usize,
388        from_cols: usize,
389        from_offset: usize,
390        from_stride: usize,
391        into_offset: usize,
392        into_stride: usize,
393    ) {
394        self.lhs.eltwise_copy_elem_slice(
395            &into.lhs,
396            from,
397            from_rows,
398            from_cols,
399            from_offset,
400            from_stride,
401            into_offset,
402            into_stride,
403        );
404        self.rhs.eltwise_copy_elem_slice(
405            &into.rhs,
406            from,
407            from_rows,
408            from_cols,
409            from_offset,
410            from_stride,
411            into_offset,
412            into_stride,
413        );
414        into.assert_eq();
415    }
416}
417
418pub struct DualCircuitHal<F, LH, RH, LC, RC>
419where
420    F: Field,
421    LH: Hal<Field = F>,
422    RH: Hal<Field = F>,
423    LC: CircuitHal<LH>,
424    RC: CircuitHal<RH>,
425{
426    lhs: Rc<LC>,
427    rhs: Rc<RC>,
428    phantom: PhantomData<(LH, RH)>,
429}
430
431impl<F, LH, RH, LC, RC> DualCircuitHal<F, LH, RH, LC, RC>
432where
433    F: Field,
434    LH: Hal<Field = F>,
435    RH: Hal<Field = F>,
436    LC: CircuitHal<LH>,
437    RC: CircuitHal<RH>,
438{
439    pub fn new(lhs: Rc<LC>, rhs: Rc<RC>) -> Self {
440        Self {
441            lhs,
442            rhs,
443            phantom: PhantomData,
444        }
445    }
446}
447
448impl<F, LH, RH, LC, RC> CircuitHal<DualHal<F, LH, RH>> for DualCircuitHal<F, LH, RH, LC, RC>
449where
450    F: Field,
451    LH: Hal<Field = F, Elem = F::Elem, ExtElem = F::ExtElem>,
452    RH: Hal<Field = F, Elem = F::Elem, ExtElem = F::ExtElem>,
453    LC: CircuitHal<LH>,
454    RC: CircuitHal<RH>,
455{
456    fn eval_check(
457        &self,
458        check: &<DualHal<F, LH, RH> as Hal>::Buffer<F::Elem>,
459        groups: &[&<DualHal<F, LH, RH> as Hal>::Buffer<F::Elem>],
460        globals: &[&<DualHal<F, LH, RH> as Hal>::Buffer<F::Elem>],
461        poly_mix: <DualHal<F, LH, RH> as Hal>::ExtElem,
462        po2: usize,
463        steps: usize,
464    ) {
465        let lhs_groups: Vec<&_> = groups.iter().map(|g| &g.lhs).collect();
466        let lhs_globals: Vec<&_> = globals.iter().map(|g| &g.lhs).collect();
467        let rhs_groups: Vec<&_> = groups.iter().map(|g| &g.rhs).collect();
468        let rhs_globals: Vec<&_> = globals.iter().map(|g| &g.rhs).collect();
469        self.lhs.eval_check(
470            &check.lhs,
471            lhs_groups.as_slice(),
472            lhs_globals.as_slice(),
473            poly_mix,
474            po2,
475            steps,
476        );
477        self.rhs.eval_check(
478            &check.rhs,
479            rhs_groups.as_slice(),
480            rhs_globals.as_slice(),
481            poly_mix,
482            po2,
483            steps,
484        );
485        check.assert_eq();
486    }
487
488    fn accumulate(
489        &self,
490        _preflight: &AccumPreflight,
491        _ctrl: &<DualHal<F, LH, RH> as Hal>::Buffer<F::Elem>,
492        _io: &<DualHal<F, LH, RH> as Hal>::Buffer<F::Elem>,
493        _data: &<DualHal<F, LH, RH> as Hal>::Buffer<F::Elem>,
494        _mix: &<DualHal<F, LH, RH> as Hal>::Buffer<F::Elem>,
495        _accum: &<DualHal<F, LH, RH> as Hal>::Buffer<F::Elem>,
496        _steps: usize,
497    ) {
498        todo!()
499    }
500}