1use std::{io, iter, marker::PhantomData, mem, panic::resume_unwind, task::Poll};
15
16use bytemuck::cast_slice_mut;
17use cryprot_core::{
18 Block,
19 aes_hash::FIXED_KEY_HASH,
20 aes_rng::AesRng,
21 alloc::allocate_zeroed_vec,
22 buf::Buf,
23 tokio_rayon::spawn_compute,
24 transpose::transpose_bitmatrix,
25 utils::{and_inplace_elem, xor_inplace},
26};
27use cryprot_net::{Connection, ConnectionError};
28use futures::{FutureExt, SinkExt, StreamExt, future::poll_fn};
29use rand::{Rng, RngExt, SeedableRng, distr::StandardUniform, rngs::StdRng};
30use subtle::{Choice, ConditionallySelectable};
31use tokio::{
32 io::{AsyncReadExt, AsyncWriteExt},
33 sync::mpsc,
34};
35use tracing::Level;
36
37use crate::{
38 BaseOt, BaseOtError, Connected, CotReceiver, CotSender, Malicious, MaliciousMarker,
39 RotReceiver, RotSender, Security, SemiHonest, SemiHonestMarker, adapter::CorrelatedFromRandom,
40 phase, random_choices,
41};
42
/// Number of base OTs performed before extension. This is also the number of
/// rows of the bit matrix that is transposed during extension.
pub const BASE_OT_COUNT: usize = 128;

/// Number of OTs extended per batch unless overridden with `with_batch_size`
/// (2^16 = 65536).
pub const DEFAULT_OT_BATCH_SIZE: usize = 1 << 16;
46
/// Sender of the OT extension protocol.
///
/// `S` is a security marker type (e.g. [`SemiHonestMarker`] or
/// [`MaliciousMarker`]) selecting the security level. During the base OTs the
/// extension sender acts as the base-OT *receiver*.
pub struct OtExtensionSender<S> {
    /// Randomness for the base-OT choice bits and for per-call RNGs.
    rng: StdRng,
    /// Base OT instance, running on a sub-connection of `conn`.
    base_ot: BaseOt,
    /// Connection to the OT extension receiver.
    conn: Connection,
    /// One RNG per base OT, seeded with the received base-OT message. Empty
    /// until base OTs have been performed (and while `send_into` is running).
    base_rngs: Vec<AesRng>,
    /// Choice bits used in the base OTs.
    base_choices: Vec<Choice>,
    /// Global correlation derived from `base_choices`; `None` before the base
    /// OTs have been performed.
    delta: Option<Block>,
    /// Number of OTs extended per batch.
    batch_size: usize,
    /// Type-level security marker; no runtime data.
    security: PhantomData<S>,
}
58
/// Receiver of the OT extension protocol.
///
/// `S` is a security marker type (e.g. [`SemiHonestMarker`] or
/// [`MaliciousMarker`]) selecting the security level. During the base OTs the
/// extension receiver acts as the base-OT *sender*.
pub struct OtExtensionReceiver<S> {
    /// Base OT instance, running on a sub-connection of `conn`.
    base_ot: BaseOt,
    /// Connection to the OT extension sender.
    conn: Connection,
    /// Two RNGs per base OT, seeded with the two base-OT messages. Empty until
    /// base OTs have been performed (and while `receive_into` is running).
    base_rngs: Vec<[AesRng; 2]>,
    /// Number of OTs extended per batch.
    batch_size: usize,
    /// Type-level security marker; no runtime data.
    security: PhantomData<S>,
    /// Randomness for per-call RNGs (e.g. the choices of the extra OTs).
    rng: StdRng,
}
68
/// OT extension sender with semi-honest security.
pub type SemiHonestOtExtensionSender = OtExtensionSender<SemiHonestMarker>;
/// OT extension receiver with semi-honest security.
pub type SemiHonestOtExtensionReceiver = OtExtensionReceiver<SemiHonestMarker>;

/// OT extension sender with malicious security (performs an additional
/// consistency check).
pub type MaliciousOtExtensionSender = OtExtensionSender<MaliciousMarker>;
/// OT extension receiver with malicious security.
pub type MaliciousOtExtensionReceiver = OtExtensionReceiver<MaliciousMarker>;
78
79#[derive(thiserror::Error, Debug)]
81#[non_exhaustive]
82pub enum Error {
83 #[error("unable to compute base OTs")]
84 BaseOT(#[from] BaseOtError),
85 #[error("connection error to peer")]
86 Connection(#[from] ConnectionError),
87 #[error("error in sending/receiving data")]
88 Communication(#[from] io::Error),
89 #[error("connection closed by peer")]
90 UnexcpectedClose,
91 #[error("Commitment does not match seed")]
93 WrongCommitment,
94 #[error("sender did not receiver x value in KOS check")]
96 MissingXValue,
97 #[error("malicious check failed")]
99 MaliciousCheck,
100 #[doc(hidden)]
101 #[error("async task is dropped. This error should not be observable.")]
102 AsyncTaskDropped,
103}
104
105impl<S: Security> OtExtensionSender<S> {
106 pub fn new(conn: Connection) -> Self {
108 Self::new_with_rng(conn, rand::make_rng())
109 }
110
111 pub fn new_with_rng(mut conn: Connection, mut rng: StdRng) -> Self {
115 let base_ot = BaseOt::new_with_rng(conn.sub_connection(), StdRng::from_rng(&mut rng));
116 Self {
117 rng,
118 base_ot,
119 conn,
120 base_rngs: vec![],
121 base_choices: vec![],
122 delta: None,
123 batch_size: DEFAULT_OT_BATCH_SIZE,
124 security: PhantomData,
125 }
126 }
127
128 pub fn with_batch_size(mut self, batch_size: usize) -> Self {
135 self.batch_size = batch_size;
136 self
137 }
138
139 pub fn batch_size(&self) -> usize {
141 self.batch_size
142 }
143
144 pub fn has_base_ots(&self) -> bool {
147 self.base_rngs.len() == BASE_OT_COUNT
148 }
149
150 pub async fn do_base_ots(&mut self) -> Result<(), Error> {
153 let base_choices = random_choices(BASE_OT_COUNT, &mut self.rng);
154 let base_ots = self.base_ot.receive(&base_choices).await?;
155 self.base_rngs = base_ots.into_iter().map(AesRng::from_seed).collect();
156 self.delta = Some(Block::from_choices(&base_choices));
157 self.base_choices = base_choices;
158 Ok(())
159 }
160}
161
162impl<S: Security> Connected for OtExtensionSender<S> {
163 fn connection(&mut self) -> &mut Connection {
164 &mut self.conn
165 }
166}
167
168impl SemiHonest for OtExtensionSender<SemiHonestMarker> {}
169impl SemiHonest for OtExtensionSender<MaliciousMarker> {}
172
173impl Malicious for OtExtensionSender<MaliciousMarker> {}
174
175impl<S: Security> RotSender for OtExtensionSender<S> {
176 type Error = Error;
177
178 #[tracing::instrument(level = Level::DEBUG, skip_all, fields(count = ots.len()))]
184 #[tracing::instrument(target = "cryprot_metrics", level = Level::TRACE, skip_all, fields(phase = phase::OT_EXTENSION))]
185 async fn send_into(&mut self, ots: &mut impl Buf<[Block; 2]>) -> Result<(), Self::Error> {
186 let count = ots.len();
187 assert_eq!(0, count % 128, "count must be multiple of 128");
188 let batch_size = self.batch_size();
189 let batches = count / batch_size;
190 let batch_size_remainder = count % batch_size;
191 let num_extra = (S::MALICIOUS_SECURITY as usize) * 128;
192
193 assert_eq!(
194 0,
195 batch_size_remainder % 128,
196 "count % batch_size must be multiple of 128"
197 );
198
199 let batch_sizes = iter::repeat_n(batch_size, batches)
200 .chain((batch_size_remainder != 0).then_some(batch_size_remainder));
201
202 if !self.has_base_ots() {
203 self.do_base_ots().await?;
204 }
205
206 let delta = self.delta.expect("base OTs are done");
207 let mut sub_conn = self.conn.sub_connection();
208
209 let (ch_s, ch_r) = std::sync::mpsc::channel::<Vec<Block>>();
211 let (kos_ch_s, mut kos_ch_r_task) = tokio::sync::mpsc::unbounded_channel::<Block>();
212 let (kos_ch_s_task, kos_ch_r) = std::sync::mpsc::channel::<Vec<Block>>();
213 let mut base_rngs = mem::take(&mut self.base_rngs);
215 let base_choices = mem::take(&mut self.base_choices);
216 let batch_sizes_th = batch_sizes.clone();
217 let owned_ots = mem::take(ots);
218 let mut rng = StdRng::from_rng(&mut self.rng);
219
220 let jh = spawn_compute(move || {
223 let mut ots = owned_ots;
224 let mut extra_messages: Vec<[Block; 2]> = Vec::zeroed(num_extra);
225 let mut transposed = Vec::zeroed(batch_size);
226 let mut owned_v_mat: Vec<Block> = if S::MALICIOUS_SECURITY {
227 Vec::zeroed(ots.len())
228 } else {
229 vec![]
230 };
231 let mut extra_v_mat = vec![Block::ZERO; num_extra];
232
233 for (ots, batch_sizes, extra) in [
234 (
235 &mut ots[..],
236 &mut batch_sizes_th.clone() as &mut dyn Iterator<Item = _>,
237 false,
238 ),
239 (&mut extra_messages[..], &mut iter::once(num_extra), true),
240 ] {
241 for (chunk_idx, (ot_batch, curr_batch_size)) in
246 ots.chunks_mut(batch_size).zip(batch_sizes).enumerate()
247 {
248 let v_mat = if S::MALICIOUS_SECURITY {
249 if extra {
250 &mut extra_v_mat
251 } else {
252 let offset = chunk_idx * batch_size;
253 &mut owned_v_mat[offset..offset + curr_batch_size]
254 }
255 } else {
256 cast_slice_mut(&mut ot_batch[..curr_batch_size / 2])
260 };
261 let v_mat = cast_slice_mut(v_mat);
262
263 let cols_byte_batch = curr_batch_size / 8;
264 let row_iter = v_mat.chunks_exact_mut(cols_byte_batch);
265
266 for ((v_row, base_rng), base_choice) in
267 row_iter.zip(&mut base_rngs).zip(&base_choices)
268 {
269 base_rng.fill_bytes(v_row);
270 let mut recv_row = ch_r.recv()?;
271 let choice_mask =
276 Block::conditional_select(&Block::ZERO, &Block::ONES, *base_choice);
277 and_inplace_elem(&mut recv_row, choice_mask);
280 let v_row = bytemuck::cast_slice_mut(v_row);
281 xor_inplace(v_row, &recv_row);
284 }
285 {
286 let transposed = bytemuck::cast_slice_mut(&mut transposed);
287 transpose_bitmatrix(v_mat, &mut transposed[..v_mat.len()], BASE_OT_COUNT);
288 }
289
290 for (v, ots) in transposed.iter().zip(ot_batch.iter_mut()) {
291 *ots = [*v, *v ^ delta]
292 }
293
294 if S::MALICIOUS_SECURITY {
295 FIXED_KEY_HASH.tccr_hash_slice_mut(
296 bytemuck::must_cast_slice_mut(ot_batch),
297 |i| {
298 Block::from(chunk_idx * batch_size + (i / 2))
303 },
304 );
305 } else {
306 FIXED_KEY_HASH.cr_hash_slice_mut(bytemuck::must_cast_slice_mut(ot_batch));
307 }
308 }
309 }
310
311 if S::MALICIOUS_SECURITY {
312 let seed: Block = rng.random();
313 kos_ch_s.send(seed)?;
314 let rng = AesRng::from_seed(seed);
315
316 let mut q1 = extra_v_mat;
317 let mut q2 = vec![Block::ZERO; BASE_OT_COUNT];
318
319 let owned_v_mat_ref = &owned_v_mat;
320
321 let challenges: Vec<Block> = rng
322 .sample_iter(StandardUniform)
323 .take(ots.len() / BASE_OT_COUNT)
324 .collect();
325
326 let block_batch_size = batch_size / BASE_OT_COUNT;
327
328 let challenge_iter =
329 batch_sizes_th
330 .clone()
331 .enumerate()
332 .flat_map(|(batch, curr_batch_size)| {
333 challenges[batch * block_batch_size
334 ..batch * block_batch_size + curr_batch_size / BASE_OT_COUNT]
335 .iter()
336 .cycle()
337 .take(curr_batch_size)
338 });
339
340 let q_idx_iter = batch_sizes_th.flat_map(|curr_batch_size| {
341 (0..BASE_OT_COUNT).flat_map(move |t_idx| {
342 iter::repeat_n(t_idx, curr_batch_size / BASE_OT_COUNT)
343 })
344 });
345
346 for ((v, s), q_idx) in owned_v_mat_ref.iter().zip(challenge_iter).zip(q_idx_iter) {
347 let (qi, qi2) = v.clmul(s);
348 q1[q_idx] ^= qi;
349 q2[q_idx] ^= qi2;
350 }
351
352 for (q1i, q2i) in q1.iter_mut().zip(&q2) {
353 *q1i = Block::gf_reduce(q1i, q2i);
354 }
355 let mut u = kos_ch_r.recv()?;
356 let Some(received_x) = u.pop() else {
357 return Err(Error::MissingXValue);
358 };
359 for ((received_t, base_choice), q1i) in u.iter().zip(&base_choices).zip(&q1) {
360 let tt =
361 Block::conditional_select(&Block::ZERO, &received_x, *base_choice) ^ *q1i;
362 if tt != *received_t {
363 return Err(Error::MaliciousCheck);
364 }
365 }
366 }
367
368 Ok::<_, Error>((ots, base_rngs, base_choices))
369 });
370
371 let (_, mut recv) = sub_conn.byte_stream().await?;
372
373 for batch_size in batch_sizes.chain((num_extra != 0).then_some(num_extra)) {
374 for _ in 0..BASE_OT_COUNT {
375 let mut recv_row = allocate_zeroed_vec(batch_size / Block::BITS);
376 recv.read_exact(bytemuck::cast_slice_mut(&mut recv_row))
377 .await?;
378 if ch_s.send(recv_row).is_err() {
379 resume_unwind(jh.await.map(drop).expect_err("expected thread error"));
383 };
384 }
385 }
386
387 if S::MALICIOUS_SECURITY {
388 let (mut kos_send, mut kos_recv) = sub_conn.byte_stream().await?;
389 let success = 'success: {
390 let Some(blk) = kos_ch_r_task.recv().await else {
391 break 'success false;
392 };
393 kos_send.as_stream().send(blk).await?;
394
395 {
396 let mut kos_recv = kos_recv.as_stream();
397 let u = kos_recv.next().await.ok_or(Error::UnexcpectedClose)??;
398 if kos_ch_s_task.send(u).is_err() {
399 break 'success false;
400 }
401 }
402
403 true
404 };
405 if !success {
406 resume_unwind(jh.await.map(drop).expect_err("expected thread error"));
407 }
408 }
409
410 let (owned_ots, base_rngs, base_choices) = match jh.await {
411 Ok(res) => res?,
412 Err(panicked) => resume_unwind(panicked),
413 };
414 self.base_rngs = base_rngs;
415 self.base_choices = base_choices;
416 *ots = owned_ots;
417 Ok(())
418 }
419}
420
421impl SemiHonest for OtExtensionReceiver<SemiHonestMarker> {}
422impl SemiHonest for OtExtensionReceiver<MaliciousMarker> {}
423
424impl Malicious for OtExtensionReceiver<MaliciousMarker> {}
425
426impl<S: Security> OtExtensionReceiver<S> {
427 pub fn new(conn: Connection) -> Self {
429 Self::new_with_rng(conn, rand::make_rng())
430 }
431
432 pub fn new_with_rng(mut conn: Connection, mut rng: StdRng) -> Self {
436 let base_ot = BaseOt::new_with_rng(conn.sub_connection(), StdRng::from_rng(&mut rng));
437 Self {
438 rng,
439 base_ot,
440 conn,
441 base_rngs: vec![],
442 batch_size: DEFAULT_OT_BATCH_SIZE,
443 security: PhantomData,
444 }
445 }
446
447 pub fn with_batch_size(mut self, batch_size: usize) -> Self {
454 self.batch_size = batch_size;
455 self
456 }
457
458 pub fn batch_size(&self) -> usize {
460 self.batch_size
461 }
462
463 pub fn has_base_ots(&self) -> bool {
466 self.base_rngs.len() == BASE_OT_COUNT
467 }
468
469 pub async fn do_base_ots(&mut self) -> Result<(), Error> {
472 let base_ots = self.base_ot.send(BASE_OT_COUNT).await?;
473 self.base_rngs = base_ots
474 .into_iter()
475 .map(|[s1, s2]| [AesRng::from_seed(s1), AesRng::from_seed(s2)])
476 .collect();
477 Ok(())
478 }
479}
480
481impl<S: Security> Connected for OtExtensionReceiver<S> {
482 fn connection(&mut self) -> &mut Connection {
483 &mut self.conn
484 }
485}
486
impl<S: Security> RotReceiver for OtExtensionReceiver<S> {
    type Error = Error;

    /// Performs `choices.len()` random OTs as the receiver, writing into `ots`
    /// one output block per choice bit.
    ///
    /// Base OTs are performed first if none are available. A compute task builds
    /// the t matrix and the rows to send; this async task writes those rows to a
    /// sub-connection. With malicious security, 128 extra OTs with random choices
    /// are extended. After all rows have been sent, the receiver answers the
    /// sender's challenge seed with its KOS check response: one t value per base
    /// OT followed by the x value.
    ///
    /// # Panics
    /// If `choices.len() != ots.len()`, if `choices.len()` is not a multiple of
    /// 128, or if `choices.len() % batch_size` is not a multiple of 128. Panics
    /// of the compute task are propagated.
    #[tracing::instrument(level = Level::DEBUG, skip_all, fields(count = ots.len()))]
    #[tracing::instrument(target = "cryprot_metrics", level = Level::TRACE, skip_all, fields(phase = phase::OT_EXTENSION))]
    async fn receive_into(
        &mut self,
        ots: &mut impl Buf<Block>,
        choices: &[Choice],
    ) -> Result<(), Self::Error> {
        assert_eq!(choices.len(), ots.len());
        assert_eq!(
            0,
            choices.len() % 128,
            "choices.len() must be multiple of 128"
        );
        let batch_size = self.batch_size();
        let count = choices.len();
        let batch_size_remainder = count % batch_size;
        assert_eq!(
            0,
            batch_size_remainder % 128,
            "count % batch_size must be multiple of 128"
        );

        if !self.has_base_ots() {
            self.do_base_ots().await?;
        }

        let mut sub_conn = self.conn.sub_connection();

        let cols_byte_batch = batch_size / 8;
        // Choice bits packed eight per byte (see `choices_to_u8_vec`).
        let choice_vec = choices_to_u8_vec(choices);

        // compute task -> async task: rows to be sent to the OT extension sender.
        let (ch_s, mut ch_r) = mpsc::unbounded_channel::<Vec<u8>>();
        // compute task -> async task: KOS check response (t values, then x).
        let (kos_ch_s, mut kos_ch_r_task) = tokio::sync::mpsc::unbounded_channel::<Vec<Block>>();
        // async task -> compute task: challenge seed received from the sender.
        let (kos_ch_s_task, kos_ch_r) = std::sync::mpsc::channel::<Block>();
        let mut rng = StdRng::from_rng(&mut self.rng);

        // The base RNGs move into the compute task and are restored only on
        // success; after an error `has_base_ots` reports `false`.
        let mut base_rngs = mem::take(&mut self.base_rngs);
        let owned_ots = mem::take(ots);
        let mut jh = spawn_compute(move || {
            let mut ots = owned_ots;
            // With malicious security the whole t matrix is kept for the check;
            // otherwise one batch worth of scratch space is reused.
            let t_mat_size = if S::MALICIOUS_SECURITY {
                ots.len()
            } else {
                batch_size
            };
            let num_extra = (S::MALICIOUS_SECURITY as usize) * 128;
            let mut t_mat = vec![Block::ZERO; t_mat_size];
            let mut extra_t_mat = vec![Block::ZERO; num_extra];
            let mut extra_messages: Vec<Block> = Vec::zeroed(num_extra);
            // Random choices for the extra OTs, which only feed the KOS check.
            let extra_choices = random_choices(num_extra, &mut rng);
            let extra_choice_vec = choices_to_u8_vec(&extra_choices);

            for (ots, choice_vec, extra) in [
                (&mut ots[..], &choice_vec, false),
                (&mut extra_messages[..], &extra_choice_vec, true),
            ] {
                for (chunk_idx, (output_chunk, choice_batch)) in ots
                    .chunks_mut(batch_size)
                    .zip(choice_vec.chunks(cols_byte_batch))
                    .enumerate()
                {
                    let curr_batch_size = output_chunk.len();
                    let chunk_t_mat = if S::MALICIOUS_SECURITY {
                        if extra {
                            &mut extra_t_mat
                        } else {
                            let offset = chunk_idx * batch_size;
                            &mut t_mat[offset..offset + curr_batch_size]
                        }
                    } else {
                        &mut t_mat[..curr_batch_size]
                    };
                    assert_eq!(output_chunk.len(), chunk_t_mat.len());
                    assert_eq!(choice_batch.len() * 8, chunk_t_mat.len());
                    let chunk_t_mat: &mut [u8] = bytemuck::must_cast_slice_mut(chunk_t_mat);
                    let cols_byte_batch = choice_batch.len();
                    // One row per base OT. The t row is filled from the first base
                    // RNG; the sent row is (second base RNG) ^ t ^ packed choices.
                    for (row, [rng1, rng2]) in chunk_t_mat
                        .chunks_exact_mut(cols_byte_batch)
                        .zip(&mut base_rngs)
                    {
                        rng1.fill_bytes(row);
                        let mut send_row = vec![0_u8; cols_byte_batch];
                        rng2.fill_bytes(&mut send_row);
                        for ((v2, v1), choices) in send_row.iter_mut().zip(row).zip(choice_batch) {
                            *v2 ^= *v1 ^ *choices;
                        }
                        ch_s.send(send_row)?;
                    }
                    let output_bytes = bytemuck::cast_slice_mut(output_chunk);
                    transpose_bitmatrix(
                        &chunk_t_mat[..BASE_OT_COUNT * cols_byte_batch],
                        output_bytes,
                        BASE_OT_COUNT,
                    );
                    if S::MALICIOUS_SECURITY {
                        // Tweak is the OT index; matches the sender's `i / 2` tweak.
                        FIXED_KEY_HASH.tccr_hash_slice_mut(output_chunk, |i| {
                            Block::from(chunk_idx * batch_size + i)
                        });
                    } else {
                        FIXED_KEY_HASH.cr_hash_slice_mut(output_chunk);
                    }
                }
            }

            if S::MALICIOUS_SECURITY {
                // All rows are produced; dropping the row sender ends the async
                // task's forwarding loop so it can proceed to the KOS exchange.
                drop(ch_s);
                let seed = kos_ch_r.recv()?;

                // The extra OTs' t rows and choices enter unchallenged.
                let mut t1 = extra_t_mat;
                let mut t2 = vec![Block::ZERO; BASE_OT_COUNT];

                let mut x1 = Block::from_choices(&extra_choices);
                let mut x2 = Block::ZERO;

                let rng = AesRng::from_seed(seed);

                let t_mat_ref = &t_mat;
                let batches = count / batch_size;
                let batch_sizes = iter::repeat_n(batch_size, batches)
                    .chain((batch_size_remainder != 0).then_some(batch_size_remainder));

                // 128 choice bits per block; block k covers OTs 128k..128k+128.
                let choice_blocks: Vec<_> = choice_vec
                    .chunks_exact(Block::BYTES)
                    .map(|chunk| Block::try_from(chunk).expect("chunk is 16 bytes"))
                    .collect();

                let challenges: Vec<Block> = rng
                    .sample_iter(StandardUniform)
                    .take(choice_blocks.len())
                    .collect();

                for (x, s) in choice_blocks.iter().zip(challenges.iter()) {
                    let (xi, xi2) = x.clmul(s);
                    x1 ^= xi;
                    x2 ^= xi2;
                }

                let block_batch_size = batch_size / BASE_OT_COUNT;

                // Same challenge-to-block pairing as the sender: block j of every
                // row of a batch uses challenge j of that batch.
                let challenge_iter =
                    batch_sizes
                        .clone()
                        .enumerate()
                        .flat_map(|(batch, curr_batch_size)| {
                            challenges[batch * block_batch_size
                                ..batch * block_batch_size + curr_batch_size / BASE_OT_COUNT]
                                .iter()
                                .cycle()
                                .take(curr_batch_size)
                        });
                // Row index for each block of `t_mat`.
                let t_idx_iter = batch_sizes.flat_map(|curr_batch_size| {
                    (0..BASE_OT_COUNT).flat_map(move |t_idx| {
                        iter::repeat_n(t_idx, curr_batch_size / BASE_OT_COUNT)
                    })
                });

                for ((t, s), t_idx) in t_mat_ref.iter().zip(challenge_iter).zip(t_idx_iter) {
                    let (ti, ti2) = t.clmul(s);
                    t1[t_idx] ^= ti;
                    t2[t_idx] ^= ti2;
                }

                for (t1i, t2i) in t1.iter_mut().zip(&mut t2) {
                    *t1i = Block::gf_reduce(t1i, t2i);
                }
                // Response layout: BASE_OT_COUNT t values followed by x.
                t1.push(Block::gf_reduce(&x1, &x2));
                kos_ch_s.send(t1)?;
            }
            Ok::<_, Error>((ots, base_rngs))
        });

        let (mut send, _) = sub_conn.byte_stream().await?;
        // Runs until the compute task drops `ch_s`.
        while let Some(row) = ch_r.recv().await {
            send.write_all(&row).await.map_err(Error::Communication)?;
        }

        if S::MALICIOUS_SECURITY {
            // The row channel also closes if the compute task panicked. Poll the
            // task once and propagate such a panic instead of waiting for the
            // sender's seed.
            // NOTE(review): a completed-with-`Ok` task is treated like `Pending`
            // here; in this branch the task should still be waiting for the seed.
            let err = poll_fn(|cx| match jh.poll_unpin(cx) {
                Poll::Ready(res) => Poll::Ready(res.map(drop)),
                Poll::Pending => Poll::Ready(Ok(())),
            })
            .await;
            if let Err(err) = err {
                resume_unwind(err);
            };
            let (mut kos_send, mut kos_recv) = sub_conn.byte_stream().await?;

            let seed = {
                let mut kos_recv = kos_recv.as_stream::<Block>();
                kos_recv.next().await.ok_or(Error::UnexcpectedClose)??
            };

            // `false` means the compute task went away unexpectedly.
            let success = 'success: {
                if kos_ch_s_task.send(seed).is_err() {
                    break 'success false;
                }

                let mut kos_send = kos_send.as_stream::<Vec<Block>>();
                let Some(v) = kos_ch_r_task.recv().await else {
                    break 'success false;
                };
                kos_send.send(v).await.map_err(Error::Communication)?;

                true
            };
            if !success {
                resume_unwind(jh.await.map(drop).expect_err("expected thread error"));
            }
        }

        let (owned_ots, base_rngs) = match jh.await {
            Ok(res) => res?,
            Err(panicked) => resume_unwind(panicked),
        };

        self.base_rngs = base_rngs;
        *ots = owned_ots;
        Ok(())
    }
}
721
722impl<S: Security> CotSender for OtExtensionSender<S> {
723 type Error = Error;
724
725 async fn correlated_send_into<B, F>(
726 &mut self,
727 ots: &mut B,
728 correlation: F,
729 ) -> Result<(), Self::Error>
730 where
731 B: Buf<Block>,
732 F: FnMut(usize) -> Block + Send,
733 {
734 CorrelatedFromRandom::new(self)
735 .correlated_send_into(ots, correlation)
736 .await
737 }
738}
739
740impl<S: Security> CotReceiver for OtExtensionReceiver<S> {
741 type Error = Error;
742
743 async fn correlated_receive_into<B>(
744 &mut self,
745 ots: &mut B,
746 choices: &[Choice],
747 ) -> Result<(), Self::Error>
748 where
749 B: Buf<Block>,
750 {
751 CorrelatedFromRandom::new(self)
752 .correlated_receive_into(ots, choices)
753 .await
754 }
755}
756
757fn choices_to_u8_vec(choices: &[Choice]) -> Vec<u8> {
758 assert_eq!(0, choices.len() % 8);
759 let mut v = vec![0_u8; choices.len() / 8];
760 for (chunk, byte) in choices.chunks_exact(8).zip(&mut v) {
761 for (i, choice) in chunk.iter().enumerate() {
762 *byte ^= choice.unwrap_u8() << i;
763 }
764 }
765 v
766}
767
768impl From<std::sync::mpsc::RecvError> for Error {
769 fn from(_: std::sync::mpsc::RecvError) -> Self {
770 Error::AsyncTaskDropped
771 }
772}
773
774impl<T> From<tokio::sync::mpsc::error::SendError<T>> for Error {
775 fn from(_: tokio::sync::mpsc::error::SendError<T>) -> Self {
776 Error::AsyncTaskDropped
777 }
778}
779
#[cfg(test)]
mod tests {

    use cryprot_core::Block;
    use cryprot_net::testing::{init_tracing, local_conn};
    use rand::{SeedableRng, rngs::StdRng};

    use crate::{
        CotReceiver, CotSender, MaliciousMarker, RotReceiver, RotSender,
        extension::{
            DEFAULT_OT_BATCH_SIZE, OtExtensionReceiver, OtExtensionSender,
            SemiHonestOtExtensionReceiver, SemiHonestOtExtensionSender,
        },
        random_choices,
    };

    /// Fixed-seed RNGs for the sender and the receiver side.
    fn seeded_rngs() -> (StdRng, StdRng) {
        (StdRng::seed_from_u64(42), StdRng::seed_from_u64(24))
    }

    /// Checks that every receiver output equals the sender message selected by
    /// the corresponding choice bit.
    fn assert_rot_matches<C: Into<bool>>(
        send_ots: impl IntoIterator<Item = [Block; 2]>,
        recv_ots: impl IntoIterator<Item = Block>,
        choices: impl IntoIterator<Item = C>,
    ) {
        for ((received, pair), choice) in recv_ots.into_iter().zip(send_ots).zip(choices) {
            let bit: bool = choice.into();
            assert_eq!(received, pair[usize::from(bit)]);
        }
    }

    #[tokio::test]
    async fn test_extension() {
        let _g = init_tracing();
        const COUNT: usize = 2 * DEFAULT_OT_BATCH_SIZE;
        let (c1, c2) = local_conn().await.unwrap();
        let (rng1, mut rng2) = seeded_rngs();
        let choices = random_choices(COUNT, &mut rng2);
        let mut sender = SemiHonestOtExtensionSender::new_with_rng(c1, rng1);
        let mut receiver = SemiHonestOtExtensionReceiver::new_with_rng(c2, rng2);
        let (send_ots, recv_ots) =
            tokio::try_join!(sender.send(COUNT), receiver.receive(&choices)).unwrap();
        assert_rot_matches(send_ots, recv_ots, choices);
    }

    #[tokio::test]
    async fn test_extension_half_batch() {
        let _g = init_tracing();
        const COUNT: usize = 2 * DEFAULT_OT_BATCH_SIZE + DEFAULT_OT_BATCH_SIZE / 2;
        let (c1, c2) = local_conn().await.unwrap();
        let (rng1, mut rng2) = seeded_rngs();
        let choices = random_choices(COUNT, &mut rng2);
        let mut sender = SemiHonestOtExtensionSender::new_with_rng(c1, rng1);
        let mut receiver = SemiHonestOtExtensionReceiver::new_with_rng(c2, rng2);
        let (send_ots, recv_ots) =
            tokio::try_join!(sender.send(COUNT), receiver.receive(&choices)).unwrap();
        assert_rot_matches(send_ots, recv_ots, choices);
    }

    #[tokio::test]
    async fn test_extension_partial_batch() {
        let _g = init_tracing();
        const COUNT: usize = DEFAULT_OT_BATCH_SIZE / 2 + 128;
        let (c1, c2) = local_conn().await.unwrap();
        let (rng1, mut rng2) = seeded_rngs();
        let choices = random_choices(COUNT, &mut rng2);
        let mut sender = SemiHonestOtExtensionSender::new_with_rng(c1, rng1);
        let mut receiver = SemiHonestOtExtensionReceiver::new_with_rng(c2, rng2);
        let (send_ots, recv_ots) =
            tokio::try_join!(sender.send(COUNT), receiver.receive(&choices)).unwrap();
        assert_rot_matches(send_ots, recv_ots, choices);
    }

    #[tokio::test]
    async fn test_extension_malicious_half_batch() {
        let _g = init_tracing();
        const COUNT: usize = DEFAULT_OT_BATCH_SIZE / 2;
        let (c1, c2) = local_conn().await.unwrap();
        let (rng1, mut rng2) = seeded_rngs();
        let choices = random_choices(COUNT, &mut rng2);
        let mut sender = OtExtensionSender::<MaliciousMarker>::new_with_rng(c1, rng1);
        let mut receiver = OtExtensionReceiver::<MaliciousMarker>::new_with_rng(c2, rng2);

        let (send_ots, recv_ots) =
            tokio::try_join!(sender.send(COUNT), receiver.receive(&choices)).unwrap();
        assert_rot_matches(send_ots, recv_ots, choices);
    }

    #[tokio::test]
    async fn test_extension_malicious_partial_batch() {
        let _g = init_tracing();
        const COUNT: usize = DEFAULT_OT_BATCH_SIZE + DEFAULT_OT_BATCH_SIZE / 2 + 128;
        let (c1, c2) = local_conn().await.unwrap();
        let (rng1, mut rng2) = seeded_rngs();
        let choices = random_choices(COUNT, &mut rng2);
        let mut sender = OtExtensionSender::<MaliciousMarker>::new_with_rng(c1, rng1);
        let mut receiver = OtExtensionReceiver::<MaliciousMarker>::new_with_rng(c2, rng2);

        let (send_ots, recv_ots) =
            tokio::try_join!(sender.send(COUNT), receiver.receive(&choices)).unwrap();
        assert_rot_matches(send_ots, recv_ots, choices);
    }

    #[tokio::test]
    async fn test_extension_malicious_multiple_batch() {
        let _g = init_tracing();
        const COUNT: usize = DEFAULT_OT_BATCH_SIZE * 2;
        let (c1, c2) = local_conn().await.unwrap();
        let (rng1, mut rng2) = seeded_rngs();
        let choices = random_choices(COUNT, &mut rng2);
        let mut sender = OtExtensionSender::<MaliciousMarker>::new_with_rng(c1, rng1);
        let mut receiver = OtExtensionReceiver::<MaliciousMarker>::new_with_rng(c2, rng2);

        let (send_ots, recv_ots) =
            tokio::try_join!(sender.send(COUNT), receiver.receive(&choices)).unwrap();
        assert_rot_matches(send_ots, recv_ots, choices);
    }

    #[tokio::test]
    async fn test_correlated_extension() {
        let _g = init_tracing();
        const COUNT: usize = 128;
        let (c1, c2) = local_conn().await.unwrap();
        let (rng1, mut rng2) = seeded_rngs();
        let choices = random_choices(COUNT, &mut rng2);
        let mut sender = SemiHonestOtExtensionSender::new_with_rng(c1, rng1);
        let mut receiver = SemiHonestOtExtensionReceiver::new_with_rng(c2, rng2);
        let (send_ots, recv_ots) = tokio::try_join!(
            sender.correlated_send(COUNT, |_| Block::ONES),
            receiver.correlated_receive(&choices)
        )
        .unwrap();
        for (i, ((r, s), c)) in recv_ots.into_iter().zip(send_ots).zip(choices).enumerate() {
            // With choice 1 the receiver's block differs from the sender's by the
            // correlation (all ones); with choice 0 they are equal.
            let expected = if bool::from(c) { r ^ Block::ONES } else { r };
            assert_eq!(expected, s, "Block {i}");
        }
    }
}