1use std::{fmt::Debug, marker::PhantomData, rc::Rc};
16
17use risc0_core::field::Field;
18
19use super::{AccumPreflight, Buffer, CircuitHal, Hal};
20use crate::core::{digest::Digest, hash::HashSuite};
21
22#[derive(Clone)]
23pub struct BufferImpl<T, L, R>
24where
25 L: Buffer<T>,
26 R: Buffer<T>,
27{
28 lhs: L,
29 rhs: R,
30 phantom: PhantomData<T>,
31}
32
33impl<T, L, R> BufferImpl<T, L, R>
34where
35 T: Debug + PartialEq,
36 L: Buffer<T>,
37 R: Buffer<T>,
38{
39 pub fn new(lhs: L, rhs: R) -> Self {
40 Self {
41 lhs,
42 rhs,
43 phantom: PhantomData,
44 }
45 }
46
47 fn assert_eq(&self) {
48 self.lhs.view(|lhs| {
49 self.rhs.view(|rhs| {
50 assert_eq!(lhs.len(), rhs.len());
51 assert_eq!(lhs, rhs);
52 });
53 })
54 }
55}
56
57impl<T, L, R> Buffer<T> for BufferImpl<T, L, R>
58where
59 T: Clone + Debug + PartialEq,
60 L: Buffer<T>,
61 R: Buffer<T>,
62{
63 fn name(&self) -> &'static str {
64 "dual"
65 }
66
67 fn size(&self) -> usize {
68 let lhs = self.lhs.size();
69 let rhs = self.rhs.size();
70 assert_eq!(lhs, rhs);
71 lhs
72 }
73
74 fn slice(&self, offset: usize, size: usize) -> Self {
75 let lhs = self.lhs.slice(offset, size);
76 let rhs = self.rhs.slice(offset, size);
77 BufferImpl::new(lhs, rhs)
78 }
79
80 fn get_at(&self, idx: usize) -> T {
81 self.lhs.get_at(idx)
82 }
83
84 fn view<F: FnOnce(&[T])>(&self, f: F) {
85 self.lhs.view(f)
86 }
87
88 fn view_mut<F: FnOnce(&mut [T])>(&self, f: F) {
89 self.lhs.view_mut(f);
90 self.rhs.view_mut(|dst| {
91 self.lhs.view(|src| dst.clone_from_slice(src));
92 })
93 }
94
95 fn to_vec(&self) -> Vec<T> {
96 self.lhs.to_vec()
97 }
98}
99
100pub struct DualHal<F, L, R>
101where
102 L: Hal<Field = F>,
103 R: Hal<Field = F>,
104{
105 lhs: Rc<L>,
106 rhs: Rc<R>,
107}
108
109impl<F, L, R> DualHal<F, L, R>
110where
111 L: Hal<Field = F>,
112 R: Hal<Field = F>,
113{
114 pub fn new(lhs: Rc<L>, rhs: Rc<R>) -> Self {
115 Self { lhs, rhs }
116 }
117}
118
119impl<F, L, R> Hal for DualHal<F, L, R>
120where
121 F: Field,
122 L: Hal<Field = F, Elem = F::Elem, ExtElem = F::ExtElem>,
123 R: Hal<Field = F, Elem = F::Elem, ExtElem = F::ExtElem>,
124{
125 type Field = F;
126 type Elem = F::Elem;
127 type ExtElem = F::ExtElem;
128 type Buffer<T: Clone + Debug + PartialEq> = BufferImpl<T, L::Buffer<T>, R::Buffer<T>>;
129
130 fn get_hash_suite(&self) -> &HashSuite<Self::Field> {
131 self.lhs.get_hash_suite()
132 }
133
134 fn alloc_digest(&self, name: &'static str, size: usize) -> Self::Buffer<Digest> {
135 let lhs = self.lhs.alloc_digest(name, size);
136 let rhs = self.rhs.alloc_digest(name, size);
137 BufferImpl::new(lhs, rhs)
138 }
139
140 fn alloc_elem(&self, name: &'static str, size: usize) -> Self::Buffer<Self::Elem> {
141 let lhs = self.lhs.alloc_elem(name, size);
142 let rhs = self.rhs.alloc_elem(name, size);
143 BufferImpl::new(lhs, rhs)
144 }
145
146 fn alloc_extelem(&self, name: &'static str, size: usize) -> Self::Buffer<Self::ExtElem> {
147 let lhs = self.lhs.alloc_extelem(name, size);
148 let rhs = self.rhs.alloc_extelem(name, size);
149 BufferImpl::new(lhs, rhs)
150 }
151
152 fn alloc_u32(&self, name: &'static str, size: usize) -> Self::Buffer<u32> {
153 let lhs = self.lhs.alloc_u32(name, size);
154 let rhs = self.rhs.alloc_u32(name, size);
155 BufferImpl::new(lhs, rhs)
156 }
157
158 fn alloc_elem_init(
159 &self,
160 name: &'static str,
161 size: usize,
162 value: Self::Elem,
163 ) -> Self::Buffer<Self::Elem> {
164 let lhs = self.lhs.alloc_elem_init(name, size, value);
165 let rhs = self.rhs.alloc_elem_init(name, size, value);
166 BufferImpl::new(lhs, rhs)
167 }
168
169 fn copy_from_digest(&self, name: &'static str, slice: &[Digest]) -> Self::Buffer<Digest> {
170 let lhs = self.lhs.copy_from_digest(name, slice);
171 let rhs = self.rhs.copy_from_digest(name, slice);
172 BufferImpl::new(lhs, rhs)
173 }
174
175 fn copy_from_elem(&self, name: &'static str, slice: &[Self::Elem]) -> Self::Buffer<Self::Elem> {
176 let lhs = self.lhs.copy_from_elem(name, slice);
177 let rhs = self.rhs.copy_from_elem(name, slice);
178 BufferImpl::new(lhs, rhs)
179 }
180
181 fn copy_from_extelem(
182 &self,
183 name: &'static str,
184 slice: &[Self::ExtElem],
185 ) -> Self::Buffer<Self::ExtElem> {
186 let lhs = self.lhs.copy_from_extelem(name, slice);
187 let rhs = self.rhs.copy_from_extelem(name, slice);
188 BufferImpl::new(lhs, rhs)
189 }
190
191 fn copy_from_u32(&self, name: &'static str, slice: &[u32]) -> Self::Buffer<u32> {
192 let lhs = self.lhs.copy_from_u32(name, slice);
193 let rhs = self.rhs.copy_from_u32(name, slice);
194 BufferImpl::new(lhs, rhs)
195 }
196
197 fn batch_expand_into_evaluate_ntt(
198 &self,
199 output: &Self::Buffer<Self::Elem>,
200 input: &Self::Buffer<Self::Elem>,
201 count: usize,
202 expand_bits: usize,
203 ) {
204 self.lhs
205 .batch_expand_into_evaluate_ntt(&output.lhs, &input.lhs, count, expand_bits);
206 self.rhs
207 .batch_expand_into_evaluate_ntt(&output.rhs, &input.rhs, count, expand_bits);
208 output.assert_eq();
209 }
210
211 fn batch_interpolate_ntt(&self, io: &Self::Buffer<Self::Elem>, count: usize) {
212 self.lhs.batch_interpolate_ntt(&io.lhs, count);
213 self.rhs.batch_interpolate_ntt(&io.rhs, count);
214 io.assert_eq();
215 }
216
217 fn batch_bit_reverse(&self, io: &Self::Buffer<Self::Elem>, count: usize) {
218 self.lhs.batch_bit_reverse(&io.lhs, count);
219 self.rhs.batch_bit_reverse(&io.rhs, count);
220 io.assert_eq();
221 }
222
223 fn batch_evaluate_any(
224 &self,
225 coeffs: &Self::Buffer<Self::Elem>,
226 poly_count: usize,
227 which: &Self::Buffer<u32>,
228 xs: &Self::Buffer<Self::ExtElem>,
229 out: &Self::Buffer<Self::ExtElem>,
230 ) {
231 self.lhs
232 .batch_evaluate_any(&coeffs.lhs, poly_count, &which.lhs, &xs.lhs, &out.lhs);
233 self.rhs
234 .batch_evaluate_any(&coeffs.rhs, poly_count, &which.rhs, &xs.rhs, &out.rhs);
235 out.assert_eq();
236 }
237
238 fn zk_shift(&self, io: &Self::Buffer<Self::Elem>, count: usize) {
239 self.lhs.zk_shift(&io.lhs, count);
240 self.rhs.zk_shift(&io.rhs, count);
241 io.assert_eq();
242 }
243
244 fn mix_poly_coeffs(
245 &self,
246 out: &Self::Buffer<Self::ExtElem>,
247 mix_start: &Self::ExtElem,
248 mix: &Self::ExtElem,
249 input: &Self::Buffer<Self::Elem>,
250 combos: &Self::Buffer<u32>,
251 input_size: usize,
252 count: usize,
253 ) {
254 self.lhs.mix_poly_coeffs(
255 &out.lhs,
256 mix_start,
257 mix,
258 &input.lhs,
259 &combos.lhs,
260 input_size,
261 count,
262 );
263 self.rhs.mix_poly_coeffs(
264 &out.rhs,
265 mix_start,
266 mix,
267 &input.rhs,
268 &combos.rhs,
269 input_size,
270 count,
271 );
272 out.assert_eq();
273 }
274
275 fn eltwise_add_elem(
276 &self,
277 output: &Self::Buffer<Self::Elem>,
278 input1: &Self::Buffer<Self::Elem>,
279 input2: &Self::Buffer<Self::Elem>,
280 ) {
281 self.lhs
282 .eltwise_add_elem(&output.lhs, &input1.lhs, &input2.lhs);
283 self.rhs
284 .eltwise_add_elem(&output.rhs, &input1.rhs, &input2.rhs);
285 output.assert_eq();
286 }
287
288 fn eltwise_sum_extelem(
289 &self,
290 output: &Self::Buffer<Self::Elem>,
291 input: &Self::Buffer<Self::ExtElem>,
292 ) {
293 self.lhs.eltwise_sum_extelem(&output.lhs, &input.lhs);
294 self.rhs.eltwise_sum_extelem(&output.rhs, &input.rhs);
295 output.assert_eq();
296 }
297
298 fn eltwise_copy_elem(
299 &self,
300 output: &Self::Buffer<Self::Elem>,
301 input: &Self::Buffer<Self::Elem>,
302 ) {
303 self.lhs.eltwise_copy_elem(&output.lhs, &input.lhs);
304 self.rhs.eltwise_copy_elem(&output.rhs, &input.rhs);
305 output.assert_eq();
306 }
307
308 fn eltwise_zeroize_elem(&self, elems: &Self::Buffer<Self::Elem>) {
309 self.lhs.eltwise_zeroize_elem(&elems.lhs);
310 self.rhs.eltwise_zeroize_elem(&elems.rhs);
311 elems.assert_eq();
312 }
313
314 fn fri_fold(
315 &self,
316 output: &Self::Buffer<Self::Elem>,
317 input: &Self::Buffer<Self::Elem>,
318 mix: &Self::ExtElem,
319 ) {
320 self.lhs.fri_fold(&output.lhs, &input.lhs, mix);
321 self.rhs.fri_fold(&output.rhs, &input.rhs, mix);
322 output.assert_eq();
323 }
324
325 fn hash_rows(&self, output: &Self::Buffer<Digest>, matrix: &Self::Buffer<Self::Elem>) {
326 self.lhs.hash_rows(&output.lhs, &matrix.lhs);
327 self.rhs.hash_rows(&output.rhs, &matrix.rhs);
328 output.assert_eq();
329 }
330
331 fn hash_fold(&self, io: &Self::Buffer<Digest>, input_size: usize, output_size: usize) {
332 self.lhs.hash_fold(&io.lhs, input_size, output_size);
333 self.rhs.hash_fold(&io.rhs, input_size, output_size);
334 io.assert_eq();
335 }
336
337 fn has_unified_memory(&self) -> bool {
338 self.rhs.has_unified_memory()
339 }
340
341 fn gather_sample(
342 &self,
343 dst: &Self::Buffer<Self::Elem>,
344 src: &Self::Buffer<Self::Elem>,
345 idx: usize,
346 size: usize,
347 stride: usize,
348 ) {
349 self.lhs
350 .gather_sample(&dst.lhs, &src.lhs, idx, size, stride);
351 self.rhs
352 .gather_sample(&dst.rhs, &src.rhs, idx, size, stride);
353 dst.assert_eq();
354 }
355
356 fn prefix_products(&self, io: &Self::Buffer<Self::ExtElem>) {
357 self.lhs.prefix_products(&io.lhs);
358 self.rhs.prefix_products(&io.rhs);
359 io.lhs.view(|lhs| {
362 io.rhs.view(|rhs| {
363 assert_eq!(lhs.len(), rhs.len());
364 for i in 0..lhs.len() {
365 assert_eq!(lhs[i], rhs[i], "{i}");
366 }
367 });
368 })
369 }
370
371 fn scatter(
372 &self,
373 into: &Self::Buffer<Self::Elem>,
374 index: &[u32],
375 offsets: &[u32],
376 values: &[Self::Elem],
377 ) {
378 self.lhs.scatter(&into.lhs, index, offsets, values);
379 self.rhs.scatter(&into.rhs, index, offsets, values);
380 into.assert_eq();
381 }
382
383 fn eltwise_copy_elem_slice(
384 &self,
385 into: &Self::Buffer<Self::Elem>,
386 from: &[Self::Elem],
387 from_rows: usize,
388 from_cols: usize,
389 from_offset: usize,
390 from_stride: usize,
391 into_offset: usize,
392 into_stride: usize,
393 ) {
394 self.lhs.eltwise_copy_elem_slice(
395 &into.lhs,
396 from,
397 from_rows,
398 from_cols,
399 from_offset,
400 from_stride,
401 into_offset,
402 into_stride,
403 );
404 self.rhs.eltwise_copy_elem_slice(
405 &into.rhs,
406 from,
407 from_rows,
408 from_cols,
409 from_offset,
410 from_stride,
411 into_offset,
412 into_stride,
413 );
414 into.assert_eq();
415 }
416}
417
418pub struct DualCircuitHal<F, LH, RH, LC, RC>
419where
420 F: Field,
421 LH: Hal<Field = F>,
422 RH: Hal<Field = F>,
423 LC: CircuitHal<LH>,
424 RC: CircuitHal<RH>,
425{
426 lhs: Rc<LC>,
427 rhs: Rc<RC>,
428 phantom: PhantomData<(LH, RH)>,
429}
430
431impl<F, LH, RH, LC, RC> DualCircuitHal<F, LH, RH, LC, RC>
432where
433 F: Field,
434 LH: Hal<Field = F>,
435 RH: Hal<Field = F>,
436 LC: CircuitHal<LH>,
437 RC: CircuitHal<RH>,
438{
439 pub fn new(lhs: Rc<LC>, rhs: Rc<RC>) -> Self {
440 Self {
441 lhs,
442 rhs,
443 phantom: PhantomData,
444 }
445 }
446}
447
448impl<F, LH, RH, LC, RC> CircuitHal<DualHal<F, LH, RH>> for DualCircuitHal<F, LH, RH, LC, RC>
449where
450 F: Field,
451 LH: Hal<Field = F, Elem = F::Elem, ExtElem = F::ExtElem>,
452 RH: Hal<Field = F, Elem = F::Elem, ExtElem = F::ExtElem>,
453 LC: CircuitHal<LH>,
454 RC: CircuitHal<RH>,
455{
456 fn eval_check(
457 &self,
458 check: &<DualHal<F, LH, RH> as Hal>::Buffer<F::Elem>,
459 groups: &[&<DualHal<F, LH, RH> as Hal>::Buffer<F::Elem>],
460 globals: &[&<DualHal<F, LH, RH> as Hal>::Buffer<F::Elem>],
461 poly_mix: <DualHal<F, LH, RH> as Hal>::ExtElem,
462 po2: usize,
463 steps: usize,
464 ) {
465 let lhs_groups: Vec<&_> = groups.iter().map(|g| &g.lhs).collect();
466 let lhs_globals: Vec<&_> = globals.iter().map(|g| &g.lhs).collect();
467 let rhs_groups: Vec<&_> = groups.iter().map(|g| &g.rhs).collect();
468 let rhs_globals: Vec<&_> = globals.iter().map(|g| &g.rhs).collect();
469 self.lhs.eval_check(
470 &check.lhs,
471 lhs_groups.as_slice(),
472 lhs_globals.as_slice(),
473 poly_mix,
474 po2,
475 steps,
476 );
477 self.rhs.eval_check(
478 &check.rhs,
479 rhs_groups.as_slice(),
480 rhs_globals.as_slice(),
481 poly_mix,
482 po2,
483 steps,
484 );
485 check.assert_eq();
486 }
487
488 fn accumulate(
489 &self,
490 _preflight: &AccumPreflight,
491 _ctrl: &<DualHal<F, LH, RH> as Hal>::Buffer<F::Elem>,
492 _io: &<DualHal<F, LH, RH> as Hal>::Buffer<F::Elem>,
493 _data: &<DualHal<F, LH, RH> as Hal>::Buffer<F::Elem>,
494 _mix: &<DualHal<F, LH, RH> as Hal>::Buffer<F::Elem>,
495 _accum: &<DualHal<F, LH, RH> as Hal>::Buffer<F::Elem>,
496 _steps: usize,
497 ) {
498 todo!()
499 }
500}