// cryprot_ot/extension.rs
//! Fast OT extension using optimized [[IKNP03](https://www.iacr.org/archive/crypto2003/27290145/27290145.pdf)] (semi-honest)
//! or [[KOS15](https://eprint.iacr.org/2015/546.pdf)] (malicious) protocol.
//!
//! The protocols are optimized for the availability of `aes` and `avx2` target
//! features for the semi-honest protocol and additionally `pclmulqdq` for the
//! malicious protocol.
//!
//! ## Batching
//! The protocols automatically compute the OTs in batches to increase
//! throughput. The [`DEFAULT_OT_BATCH_SIZE`] has been chosen to maximise
//! throughput in very low latency settings for large numbers of OTs.
//! The batch size can be changed using the corresponding methods on the sender
//! and receiver (e.g. [`OtExtensionSender::with_batch_size`]).
14use std::{io, iter, marker::PhantomData, mem, panic::resume_unwind, task::Poll};
15
16use bytemuck::cast_slice_mut;
17use cryprot_core::{
18    Block,
19    aes_hash::FIXED_KEY_HASH,
20    aes_rng::AesRng,
21    alloc::allocate_zeroed_vec,
22    buf::Buf,
23    tokio_rayon::spawn_compute,
24    transpose::transpose_bitmatrix,
25    utils::{and_inplace_elem, xor_inplace},
26};
27use cryprot_net::{Connection, ConnectionError};
28use futures::{FutureExt, SinkExt, StreamExt, future::poll_fn};
29use rand::{Rng, RngExt, SeedableRng, distr::StandardUniform, rngs::StdRng};
30use subtle::{Choice, ConditionallySelectable};
31use tokio::{
32    io::{AsyncReadExt, AsyncWriteExt},
33    sync::mpsc,
34};
35use tracing::Level;
36
37use crate::{
38    BaseOt, BaseOtError, Connected, CotReceiver, CotSender, Malicious, MaliciousMarker,
39    RotReceiver, RotSender, Security, SemiHonest, SemiHonestMarker, adapter::CorrelatedFromRandom,
40    phase, random_choices,
41};
42
/// Number of base OTs used to seed the OT extension.
pub const BASE_OT_COUNT: usize = 128;

/// Default number of OTs that are computed per batch (2^16).
pub const DEFAULT_OT_BATCH_SIZE: usize = 2_usize.pow(16);
46
/// OT extension sender generic over its [`Security`] level.
pub struct OtExtensionSender<S> {
    // rng used to draw the base OT choice bits (and, in the malicious
    // variant, the KOS challenge seed)
    rng: StdRng,
    // base OT protocol used to seed the extension
    base_ot: BaseOt,
    conn: Connection,
    // one rng per base OT, seeded with the received base OT message
    base_rngs: Vec<AesRng>,
    // choice bits used in the base OTs; these define `delta`
    base_choices: Vec<Choice>,
    // correlation between the two sender messages; `Some` once base OTs are done
    delta: Option<Block>,
    batch_size: usize,
    // zero-sized marker carrying the security level
    security: PhantomData<S>,
}
58
/// OT extension receiver generic over its [`Security`] level.
pub struct OtExtensionReceiver<S> {
    // base OT protocol used to seed the extension
    base_ot: BaseOt,
    conn: Connection,
    // two rngs per base OT, seeded with the pair of sent base OT messages
    base_rngs: Vec<[AesRng; 2]>,
    batch_size: usize,
    // zero-sized marker carrying the security level
    security: PhantomData<S>,
    // rng used for the extra random choices of the KOS check
    rng: StdRng,
}
68
/// Semi-honest OT extension sender alias.
pub type SemiHonestOtExtensionSender = OtExtensionSender<SemiHonestMarker>;
/// Semi-honest OT extension receiver alias.
pub type SemiHonestOtExtensionReceiver = OtExtensionReceiver<SemiHonestMarker>;

/// Malicious OT extension sender alias.
pub type MaliciousOtExtensionSender = OtExtensionSender<MaliciousMarker>;
/// Malicious OT extension receiver alias.
pub type MaliciousOtExtensionReceiver = OtExtensionReceiver<MaliciousMarker>;
78
79/// Error type returned by the OT extension protocols.
80#[derive(thiserror::Error, Debug)]
81#[non_exhaustive]
82pub enum Error {
83    #[error("unable to compute base OTs")]
84    BaseOT(#[from] BaseOtError),
85    #[error("connection error to peer")]
86    Connection(#[from] ConnectionError),
87    #[error("error in sending/receiving data")]
88    Communication(#[from] io::Error),
89    #[error("connection closed by peer")]
90    UnexcpectedClose,
91    /// Only possible for malicious variant.
92    #[error("Commitment does not match seed")]
93    WrongCommitment,
94    /// Only possible for malicious variant.
95    #[error("sender did not receiver x value in KOS check")]
96    MissingXValue,
97    /// Only possible for malicious variant.
98    #[error("malicious check failed")]
99    MaliciousCheck,
100    #[doc(hidden)]
101    #[error("async task is dropped. This error should not be observable.")]
102    AsyncTaskDropped,
103}
104
impl<S: Security> OtExtensionSender<S> {
    /// Create a new sender for the given [`Connection`].
    pub fn new(conn: Connection) -> Self {
        Self::new_with_rng(conn, rand::make_rng())
    }

    /// Create a new sender for the given [`Connection`] and [`StdRng`].
    ///
    /// For an rng seeded with a fixed seed, the output is deterministic.
    pub fn new_with_rng(mut conn: Connection, mut rng: StdRng) -> Self {
        // base OTs run over their own sub-connection, separate from the
        // extension messages
        let base_ot = BaseOt::new_with_rng(conn.sub_connection(), StdRng::from_rng(&mut rng));
        Self {
            rng,
            base_ot,
            conn,
            base_rngs: vec![],
            base_choices: vec![],
            delta: None,
            batch_size: DEFAULT_OT_BATCH_SIZE,
            security: PhantomData,
        }
    }

    /// Set the OT batch size for the sender.
    ///
    /// If the sender batch size is changed, the receiver's must also be changed
    /// (see [`OtExtensionReceiver::with_batch_size`]).
    /// Note that [`OtExtensionSender::send`] methods will fail if `count %
    /// self.batch_size()` is not divisible by 128.
    pub fn with_batch_size(mut self, batch_size: usize) -> Self {
        self.batch_size = batch_size;
        self
    }

    /// The currently configured OT batch size.
    pub fn batch_size(&self) -> usize {
        self.batch_size
    }

    /// Returns true if base OTs have been performed. Subsequent calls to send
    /// will not perform base OTs again.
    pub fn has_base_ots(&self) -> bool {
        self.base_rngs.len() == BASE_OT_COUNT
    }

    /// Perform base OTs for later extension. Subsequent calls to send
    /// will not perform base OTs again.
    pub async fn do_base_ots(&mut self) -> Result<(), Error> {
        // the extension sender acts as the *receiver* of the base OTs
        let base_choices = random_choices(BASE_OT_COUNT, &mut self.rng);
        let base_ots = self.base_ot.receive(&base_choices).await?;
        self.base_rngs = base_ots.into_iter().map(AesRng::from_seed).collect();
        // delta, the sender's correlation, is defined by the base choice bits
        self.delta = Some(Block::from_choices(&base_choices));
        self.base_choices = base_choices;
        Ok(())
    }
}
161
impl<S: Security> Connected for OtExtensionSender<S> {
    /// Exposes the sender's underlying [`Connection`].
    fn connection(&mut self) -> &mut Connection {
        &mut self.conn
    }
}
167
impl SemiHonest for OtExtensionSender<SemiHonestMarker> {}
/// A maliciously secure sender also offers semi-honest security at decreased
/// performance.
impl SemiHonest for OtExtensionSender<MaliciousMarker> {}

impl Malicious for OtExtensionSender<MaliciousMarker> {}
174
impl<S: Security> RotSender for OtExtensionSender<S> {
    type Error = Error;

    /// Sender part of OT extension.
    ///
    /// # Panics
    /// - If `count` is not divisible by 128.
    /// - If `count % self.batch_size()` is not divisible by 128.
    #[tracing::instrument(level = Level::DEBUG, skip_all, fields(count = ots.len()))]
    #[tracing::instrument(target = "cryprot_metrics", level = Level::TRACE, skip_all, fields(phase = phase::OT_EXTENSION))]
    async fn send_into(&mut self, ots: &mut impl Buf<[Block; 2]>) -> Result<(), Self::Error> {
        let count = ots.len();
        assert_eq!(0, count % 128, "count must be multiple of 128");
        let batch_size = self.batch_size();
        let batches = count / batch_size;
        let batch_size_remainder = count % batch_size;
        // the malicious variant computes 128 extra OTs which are consumed by
        // the KOS consistency check; num_extra is 0 for semi-honest
        let num_extra = (S::MALICIOUS_SECURITY as usize) * 128;

        assert_eq!(
            0,
            batch_size_remainder % 128,
            "count % batch_size must be multiple of 128"
        );

        // `batches` many full batches, optionally followed by one smaller batch
        let batch_sizes = iter::repeat_n(batch_size, batches)
            .chain((batch_size_remainder != 0).then_some(batch_size_remainder));

        if !self.has_base_ots() {
            self.do_base_ots().await?;
        }

        let delta = self.delta.expect("base OTs are done");
        let mut sub_conn = self.conn.sub_connection();

        // channel for communication between async task and compute thread
        let (ch_s, ch_r) = std::sync::mpsc::channel::<Vec<Block>>();
        let (kos_ch_s, mut kos_ch_r_task) = tokio::sync::mpsc::unbounded_channel::<Block>();
        let (kos_ch_s_task, kos_ch_r) = std::sync::mpsc::channel::<Vec<Block>>();
        // take these to move them into compute thread, will be returned via ret channel
        let mut base_rngs = mem::take(&mut self.base_rngs);
        let base_choices = mem::take(&mut self.base_choices);
        let batch_sizes_th = batch_sizes.clone();
        let owned_ots = mem::take(ots);
        let mut rng = StdRng::from_rng(&mut self.rng);

        // spawn compute thread for CPU intensive work. This way we increase throughput
        // and don't risk blocking tokio worker threads
        let jh = spawn_compute(move || {
            let mut ots = owned_ots;
            let mut extra_messages: Vec<[Block; 2]> = Vec::zeroed(num_extra);
            let mut transposed = Vec::zeroed(batch_size);
            // the malicious check needs the whole V matrix at the end, so it
            // gets its own allocation instead of reusing the OT buffer
            let mut owned_v_mat: Vec<Block> = if S::MALICIOUS_SECURITY {
                Vec::zeroed(ots.len())
            } else {
                vec![]
            };
            let mut extra_v_mat = vec![Block::ZERO; num_extra];

            // first iteration processes the real OTs, the second the extra OTs
            // for the KOS check (empty in the semi-honest variant)
            for (ots, batch_sizes, extra) in [
                (
                    &mut ots[..],
                    &mut batch_sizes_th.clone() as &mut dyn Iterator<Item = _>,
                    false,
                ),
                (&mut extra_messages[..], &mut iter::once(num_extra), true),
            ] {
                // to increase throughput, we divide the `count` many OTs into batches of size
                // self.batch_size(). Crucially, this allows us to do the transpose
                // and hash step while not having received the complete data from the
                // OtExtensionReceiver.
                for (chunk_idx, (ot_batch, curr_batch_size)) in
                    ots.chunks_mut(batch_size).zip(batch_sizes).enumerate()
                {
                    let v_mat = if S::MALICIOUS_SECURITY {
                        if extra {
                            &mut extra_v_mat
                        } else {
                            let offset = chunk_idx * batch_size;
                            &mut owned_v_mat[offset..offset + curr_batch_size]
                        }
                    } else {
                        // we temporarily use the output OT buffer to hold the current chunk of the
                        // V matrix which we XOR with our received row or 0
                        // and then transpose into `transposed`
                        cast_slice_mut(&mut ot_batch[..curr_batch_size / 2])
                    };
                    let v_mat = cast_slice_mut(v_mat);

                    let cols_byte_batch = curr_batch_size / 8;
                    let row_iter = v_mat.chunks_exact_mut(cols_byte_batch);

                    for ((v_row, base_rng), base_choice) in
                        row_iter.zip(&mut base_rngs).zip(&base_choices)
                    {
                        base_rng.fill_bytes(v_row);
                        let mut recv_row = ch_r.recv()?;
                        // constant time version of
                        // if !base_choice {
                        //   v_row ^= recv_row;
                        // }
                        let choice_mask =
                            Block::conditional_select(&Block::ZERO, &Block::ONES, *base_choice);
                        // if choice_mask == 0, we zero out recv_row
                        // if choice_mask == 1, recv_row is not changed
                        and_inplace_elem(&mut recv_row, choice_mask);
                        let v_row = bytemuck::cast_slice_mut(v_row);
                        // if choice_mask == 0, v_row = v_row ^ 000000..
                        // if choice_mask == 1, v_row = v_row ^ recv_row
                        xor_inplace(v_row, &recv_row);
                    }
                    {
                        let transposed = bytemuck::cast_slice_mut(&mut transposed);
                        transpose_bitmatrix(v_mat, &mut transposed[..v_mat.len()], BASE_OT_COUNT);
                    }

                    // the second sender message is correlated via delta
                    for (v, ots) in transposed.iter().zip(ot_batch.iter_mut()) {
                        *ots = [*v, *v ^ delta]
                    }

                    if S::MALICIOUS_SECURITY {
                        FIXED_KEY_HASH.tccr_hash_slice_mut(
                            bytemuck::must_cast_slice_mut(ot_batch),
                            |i| {
                                // use batch_size here, which is the batch_size of all batches
                                // except potentially the last. If we use curr_batch_size, our
                                // offset would be wrong for the last batch if curr_batch_size <
                                // batch_size
                                Block::from(chunk_idx * batch_size + (i / 2))
                            },
                        );
                    } else {
                        FIXED_KEY_HASH.cr_hash_slice_mut(bytemuck::must_cast_slice_mut(ot_batch));
                    }
                }
            }

            // KOS consistency check (malicious variant only)
            if S::MALICIOUS_SECURITY {
                let seed: Block = rng.random();
                kos_ch_s.send(seed)?;
                let rng = AesRng::from_seed(seed);

                let mut q1 = extra_v_mat;
                let mut q2 = vec![Block::ZERO; BASE_OT_COUNT];

                let owned_v_mat_ref = &owned_v_mat;

                let challenges: Vec<Block> = rng
                    .sample_iter(StandardUniform)
                    .take(ots.len() / BASE_OT_COUNT)
                    .collect();

                let block_batch_size = batch_size / BASE_OT_COUNT;

                let challenge_iter =
                    batch_sizes_th
                        .clone()
                        .enumerate()
                        .flat_map(|(batch, curr_batch_size)| {
                            challenges[batch * block_batch_size
                                ..batch * block_batch_size + curr_batch_size / BASE_OT_COUNT]
                                .iter()
                                .cycle()
                                .take(curr_batch_size)
                        });

                let q_idx_iter = batch_sizes_th.flat_map(|curr_batch_size| {
                    (0..BASE_OT_COUNT).flat_map(move |t_idx| {
                        iter::repeat_n(t_idx, curr_batch_size / BASE_OT_COUNT)
                    })
                });

                for ((v, s), q_idx) in owned_v_mat_ref.iter().zip(challenge_iter).zip(q_idx_iter) {
                    let (qi, qi2) = v.clmul(s);
                    q1[q_idx] ^= qi;
                    q2[q_idx] ^= qi2;
                }

                for (q1i, q2i) in q1.iter_mut().zip(&q2) {
                    *q1i = Block::gf_reduce(q1i, q2i);
                }
                let mut u = kos_ch_r.recv()?;
                // the receiver sends its x value as the last vector element
                let Some(received_x) = u.pop() else {
                    return Err(Error::MissingXValue);
                };
                for ((received_t, base_choice), q1i) in u.iter().zip(&base_choices).zip(&q1) {
                    let tt =
                        Block::conditional_select(&Block::ZERO, &received_x, *base_choice) ^ *q1i;
                    if tt != *received_t {
                        return Err(Error::MaliciousCheck);
                    }
                }
            }

            Ok::<_, Error>((ots, base_rngs, base_choices))
        });

        let (_, mut recv) = sub_conn.byte_stream().await?;

        // forward the receiver's masked rows to the compute thread as they arrive
        for batch_size in batch_sizes.chain((num_extra != 0).then_some(num_extra)) {
            for _ in 0..BASE_OT_COUNT {
                let mut recv_row = allocate_zeroed_vec(batch_size / Block::BITS);
                recv.read_exact(bytemuck::cast_slice_mut(&mut recv_row))
                    .await?;
                if ch_s.send(recv_row).is_err() {
                    // If we can't send on the channel, the channel must've been dropped due to a
                    // panic in the worker thread. So we try to join the compute task to resume the
                    // panic
                    resume_unwind(jh.await.map(drop).expect_err("expected thread error"));
                };
            }
        }

        if S::MALICIOUS_SECURITY {
            let (mut kos_send, mut kos_recv) = sub_conn.byte_stream().await?;
            let success = 'success: {
                let Some(blk) = kos_ch_r_task.recv().await else {
                    break 'success false;
                };
                kos_send.as_stream().send(blk).await?;

                {
                    let mut kos_recv = kos_recv.as_stream();
                    let u = kos_recv.next().await.ok_or(Error::UnexcpectedClose)??;
                    if kos_ch_s_task.send(u).is_err() {
                        break 'success false;
                    }
                }

                true
            };
            if !success {
                // a closed channel implies a panicked compute thread; join it to
                // propagate the panic
                resume_unwind(jh.await.map(drop).expect_err("expected thread error"));
            }
        }

        let (owned_ots, base_rngs, base_choices) = match jh.await {
            Ok(res) => res?,
            Err(panicked) => resume_unwind(panicked),
        };
        // restore the state taken earlier and hand the OTs back to the caller
        self.base_rngs = base_rngs;
        self.base_choices = base_choices;
        *ots = owned_ots;
        Ok(())
    }
}
420
impl SemiHonest for OtExtensionReceiver<SemiHonestMarker> {}
/// A maliciously secure receiver also offers semi-honest security at decreased
/// performance.
impl SemiHonest for OtExtensionReceiver<MaliciousMarker> {}

impl Malicious for OtExtensionReceiver<MaliciousMarker> {}
425
impl<S: Security> OtExtensionReceiver<S> {
    /// Create a new receiver for the given [`Connection`].
    pub fn new(conn: Connection) -> Self {
        Self::new_with_rng(conn, rand::make_rng())
    }

    /// Create a new receiver for the given [`Connection`] and [`StdRng`].
    ///
    /// For an rng seeded with a fixed seed, the output is deterministic.
    pub fn new_with_rng(mut conn: Connection, mut rng: StdRng) -> Self {
        // base OTs run over their own sub-connection, separate from the
        // extension messages
        let base_ot = BaseOt::new_with_rng(conn.sub_connection(), StdRng::from_rng(&mut rng));
        Self {
            rng,
            base_ot,
            conn,
            base_rngs: vec![],
            batch_size: DEFAULT_OT_BATCH_SIZE,
            security: PhantomData,
        }
    }

    /// Set the OT batch size for the receiver.
    ///
    /// If the receiver batch size is changed, the sender's must also be
    /// changed (see [`OtExtensionSender::with_batch_size`]).
    /// Note that [`OtExtensionReceiver::receive`] methods will fail if `count %
    /// self.batch_size()` is not divisible by 128.
    pub fn with_batch_size(mut self, batch_size: usize) -> Self {
        self.batch_size = batch_size;
        self
    }

    /// The currently configured OT batch size.
    pub fn batch_size(&self) -> usize {
        self.batch_size
    }

    /// Returns true if base OTs have been performed. Subsequent calls to
    /// receive will not perform base OTs again.
    pub fn has_base_ots(&self) -> bool {
        self.base_rngs.len() == BASE_OT_COUNT
    }

    /// Perform base OTs for later extension. Subsequent calls to receive
    /// will not perform base OTs again.
    pub async fn do_base_ots(&mut self) -> Result<(), Error> {
        // the extension receiver acts as the *sender* of the base OTs
        let base_ots = self.base_ot.send(BASE_OT_COUNT).await?;
        self.base_rngs = base_ots
            .into_iter()
            .map(|[s1, s2]| [AesRng::from_seed(s1), AesRng::from_seed(s2)])
            .collect();
        Ok(())
    }
}
480
impl<S: Security> Connected for OtExtensionReceiver<S> {
    /// Exposes the receiver's underlying [`Connection`].
    fn connection(&mut self) -> &mut Connection {
        &mut self.conn
    }
}
486
impl<S: Security> RotReceiver for OtExtensionReceiver<S> {
    type Error = Error;

    /// Receiver part of OT extension.
    ///
    /// # Panics
    /// - If `choices.len()` is not divisible by 128.
    /// - If `choices.len() % self.batch_size()` is not divisible by 128.
    #[tracing::instrument(level = Level::DEBUG, skip_all, fields(count = ots.len()))]
    #[tracing::instrument(target = "cryprot_metrics", level = Level::TRACE, skip_all, fields(phase = phase::OT_EXTENSION))]
    async fn receive_into(
        &mut self,
        ots: &mut impl Buf<Block>,
        choices: &[Choice],
    ) -> Result<(), Self::Error> {
        assert_eq!(choices.len(), ots.len());
        assert_eq!(
            0,
            choices.len() % 128,
            "choices.len() must be multiple of 128"
        );
        let batch_size = self.batch_size();
        let count = choices.len();
        let batch_size_remainder = count % batch_size;
        assert_eq!(
            0,
            batch_size_remainder % 128,
            "count % batch_size must be multiple of 128"
        );

        if !self.has_base_ots() {
            self.do_base_ots().await?;
        }

        let mut sub_conn = self.conn.sub_connection();

        let cols_byte_batch = batch_size / 8;
        // pack the choice bits into bytes so they can be XORed into the rows
        let choice_vec = choices_to_u8_vec(choices);

        // channels for communication between async task and compute thread
        let (ch_s, mut ch_r) = mpsc::unbounded_channel::<Vec<u8>>();
        let (kos_ch_s, mut kos_ch_r_task) = tokio::sync::mpsc::unbounded_channel::<Vec<Block>>();
        let (kos_ch_s_task, kos_ch_r) = std::sync::mpsc::channel::<Block>();
        let mut rng = StdRng::from_rng(&mut self.rng);

        // take these to move them into the compute thread; returned on join
        let mut base_rngs = mem::take(&mut self.base_rngs);
        let owned_ots = mem::take(ots);
        let mut jh = spawn_compute(move || {
            let mut ots = owned_ots;
            // the malicious check needs the whole T matrix at the end; the
            // semi-honest variant only ever needs one batch of it
            let t_mat_size = if S::MALICIOUS_SECURITY {
                ots.len()
            } else {
                batch_size
            };
            // 128 extra OTs with random choices are produced for the KOS check
            let num_extra = (S::MALICIOUS_SECURITY as usize) * 128;
            let mut t_mat = vec![Block::ZERO; t_mat_size];
            let mut extra_t_mat = vec![Block::ZERO; num_extra];
            let mut extra_messages: Vec<Block> = Vec::zeroed(num_extra);
            let extra_choices = random_choices(num_extra, &mut rng);
            let extra_choice_vec = choices_to_u8_vec(&extra_choices);

            // first iteration processes the real OTs, the second the extra OTs
            // for the KOS check (empty in the semi-honest variant)
            for (ots, choice_vec, extra) in [
                (&mut ots[..], &choice_vec, false),
                (&mut extra_messages[..], &extra_choice_vec, true),
            ] {
                for (chunk_idx, (output_chunk, choice_batch)) in ots
                    .chunks_mut(batch_size)
                    .zip(choice_vec.chunks(cols_byte_batch))
                    .enumerate()
                {
                    let curr_batch_size = output_chunk.len();
                    let chunk_t_mat = if S::MALICIOUS_SECURITY {
                        if extra {
                            &mut extra_t_mat
                        } else {
                            let offset = chunk_idx * batch_size;
                            &mut t_mat[offset..offset + curr_batch_size]
                        }
                    } else {
                        &mut t_mat[..curr_batch_size]
                    };
                    assert_eq!(output_chunk.len(), chunk_t_mat.len());
                    assert_eq!(choice_batch.len() * 8, chunk_t_mat.len());
                    let chunk_t_mat: &mut [u8] = bytemuck::must_cast_slice_mut(chunk_t_mat);
                    // might change for last chunk
                    let cols_byte_batch = choice_batch.len();
                    for (row, [rng1, rng2]) in chunk_t_mat
                        .chunks_exact_mut(cols_byte_batch)
                        .zip(&mut base_rngs)
                    {
                        rng1.fill_bytes(row);
                        let mut send_row = vec![0_u8; cols_byte_batch];
                        rng2.fill_bytes(&mut send_row);
                        // TODO wouldn't this be better on Blocks instead of u8?
                        for ((v2, v1), choices) in send_row.iter_mut().zip(row).zip(choice_batch) {
                            *v2 ^= *v1 ^ *choices;
                        }
                        ch_s.send(send_row)?;
                    }
                    let output_bytes = bytemuck::cast_slice_mut(output_chunk);
                    transpose_bitmatrix(
                        &chunk_t_mat[..BASE_OT_COUNT * cols_byte_batch],
                        output_bytes,
                        BASE_OT_COUNT,
                    );
                    if S::MALICIOUS_SECURITY {
                        FIXED_KEY_HASH.tccr_hash_slice_mut(output_chunk, |i| {
                            Block::from(chunk_idx * batch_size + i)
                        });
                    } else {
                        FIXED_KEY_HASH.cr_hash_slice_mut(output_chunk);
                    }
                }
            }

            // KOS consistency check (malicious variant only)
            if S::MALICIOUS_SECURITY {
                // dropping ch_s is important so the async task exits the ch_r loop
                drop(ch_s);
                let seed = kos_ch_r.recv()?;

                let mut t1 = extra_t_mat;
                let mut t2 = vec![Block::ZERO; BASE_OT_COUNT];

                let mut x1 = Block::from_choices(&extra_choices);
                let mut x2 = Block::ZERO;

                let rng = AesRng::from_seed(seed);

                let t_mat_ref = &t_mat;
                let batches = count / batch_size;
                let batch_sizes = iter::repeat_n(batch_size, batches)
                    .chain((batch_size_remainder != 0).then_some(batch_size_remainder));

                let choice_blocks: Vec<_> = choice_vec
                    .chunks_exact(Block::BYTES)
                    .map(|chunk| Block::try_from(chunk).expect("chunk is 16 bytes"))
                    .collect();

                let challenges: Vec<Block> = rng
                    .sample_iter(StandardUniform)
                    .take(choice_blocks.len())
                    .collect();

                for (x, s) in choice_blocks.iter().zip(challenges.iter()) {
                    let (xi, xi2) = x.clmul(s);
                    x1 ^= xi;
                    x2 ^= xi2;
                }

                let block_batch_size = batch_size / BASE_OT_COUNT;

                let challenge_iter =
                    batch_sizes
                        .clone()
                        .enumerate()
                        .flat_map(|(batch, curr_batch_size)| {
                            challenges[batch * block_batch_size
                                ..batch * block_batch_size + curr_batch_size / BASE_OT_COUNT]
                                .iter()
                                .cycle()
                                .take(curr_batch_size)
                        });
                let t_idx_iter = batch_sizes.flat_map(|curr_batch_size| {
                    (0..BASE_OT_COUNT).flat_map(move |t_idx| {
                        iter::repeat_n(t_idx, curr_batch_size / BASE_OT_COUNT)
                    })
                });

                for ((t, s), t_idx) in t_mat_ref.iter().zip(challenge_iter).zip(t_idx_iter) {
                    let (ti, ti2) = t.clmul(s);
                    t1[t_idx] ^= ti;
                    t2[t_idx] ^= ti2;
                }

                for (t1i, t2i) in t1.iter_mut().zip(&mut t2) {
                    *t1i = Block::gf_reduce(t1i, t2i);
                }
                // the x value is sent as the last element of the vector
                t1.push(Block::gf_reduce(&x1, &x2));
                kos_ch_s.send(t1)?;
            }
            Ok::<_, Error>((ots, base_rngs))
        });

        let (mut send, _) = sub_conn.byte_stream().await?;
        // forward the masked rows from the compute thread to the sender
        while let Some(row) = ch_r.recv().await {
            send.write_all(&row).await.map_err(Error::Communication)?;
        }

        if S::MALICIOUS_SECURITY {
            // If the worker thread panics we break early from the above loop. We check for
            // the panic to prevent a deadlock where we try to get the next message but the
            // peer is still in the worker thread
            let err = poll_fn(|cx| match jh.poll_unpin(cx) {
                Poll::Ready(res) => Poll::Ready(res.map(drop)),
                Poll::Pending => Poll::Ready(Ok(())),
            })
            .await;
            if let Err(err) = err {
                resume_unwind(err);
            };
            let (mut kos_send, mut kos_recv) = sub_conn.byte_stream().await?;

            let seed = {
                let mut kos_recv = kos_recv.as_stream::<Block>();
                kos_recv.next().await.ok_or(Error::UnexcpectedClose)??
            };

            let success = 'success: {
                if kos_ch_s_task.send(seed).is_err() {
                    break 'success false;
                }

                let mut kos_send = kos_send.as_stream::<Vec<Block>>();
                let Some(v) = kos_ch_r_task.recv().await else {
                    break 'success false;
                };
                kos_send.send(v).await.map_err(Error::Communication)?;

                true
            };
            if !success {
                // a closed channel implies a panicked compute thread; join it to
                // propagate the panic
                resume_unwind(jh.await.map(drop).expect_err("expected thread error"));
            }
        }

        let (owned_ots, base_rngs) = match jh.await {
            Ok(res) => res?,
            Err(panicked) => resume_unwind(panicked),
        };

        // restore the state taken earlier and hand the OTs back to the caller
        self.base_rngs = base_rngs;
        *ots = owned_ots;
        Ok(())
    }
}
721
impl<S: Security> CotSender for OtExtensionSender<S> {
    type Error = Error;

    /// Correlated OT send, implemented by adapting the random OT extension
    /// via [`CorrelatedFromRandom`].
    async fn correlated_send_into<B, F>(
        &mut self,
        ots: &mut B,
        correlation: F,
    ) -> Result<(), Self::Error>
    where
        B: Buf<Block>,
        F: FnMut(usize) -> Block + Send,
    {
        CorrelatedFromRandom::new(self)
            .correlated_send_into(ots, correlation)
            .await
    }
}
739
impl<S: Security> CotReceiver for OtExtensionReceiver<S> {
    type Error = Error;

    /// Correlated OT receive, implemented by adapting the random OT extension
    /// via [`CorrelatedFromRandom`].
    async fn correlated_receive_into<B>(
        &mut self,
        ots: &mut B,
        choices: &[Choice],
    ) -> Result<(), Self::Error>
    where
        B: Buf<Block>,
    {
        CorrelatedFromRandom::new(self)
            .correlated_receive_into(ots, choices)
            .await
    }
}
756
757fn choices_to_u8_vec(choices: &[Choice]) -> Vec<u8> {
758    assert_eq!(0, choices.len() % 8);
759    let mut v = vec![0_u8; choices.len() / 8];
760    for (chunk, byte) in choices.chunks_exact(8).zip(&mut v) {
761        for (i, choice) in chunk.iter().enumerate() {
762            *byte ^= choice.unwrap_u8() << i;
763        }
764    }
765    v
766}
767
768impl From<std::sync::mpsc::RecvError> for Error {
769    fn from(_: std::sync::mpsc::RecvError) -> Self {
770        Error::AsyncTaskDropped
771    }
772}
773
774impl<T> From<tokio::sync::mpsc::error::SendError<T>> for Error {
775    fn from(_: tokio::sync::mpsc::error::SendError<T>) -> Self {
776        Error::AsyncTaskDropped
777    }
778}
779
#[cfg(test)]
mod tests {

    use cryprot_core::Block;
    use cryprot_net::testing::{init_tracing, local_conn};
    use rand::{SeedableRng, rngs::StdRng};

    use crate::{
        CotReceiver, CotSender, MaliciousMarker, RotReceiver, RotSender,
        extension::{
            DEFAULT_OT_BATCH_SIZE, OtExtensionReceiver, OtExtensionSender,
            SemiHonestOtExtensionReceiver, SemiHonestOtExtensionSender,
        },
        random_choices,
    };

    /// Runs a semi-honest random OT extension for `count` OTs over a local
    /// connection and asserts that every received block equals the sender
    /// block selected by the corresponding choice bit.
    async fn check_semi_honest_rot(count: usize) {
        let (c1, c2) = local_conn().await.unwrap();
        let rng1 = StdRng::seed_from_u64(42);
        let mut rng2 = StdRng::seed_from_u64(24);
        let choices = random_choices(count, &mut rng2);
        let mut sender = SemiHonestOtExtensionSender::new_with_rng(c1, rng1);
        let mut receiver = SemiHonestOtExtensionReceiver::new_with_rng(c2, rng2);
        let (send_ots, recv_ots) =
            tokio::try_join!(sender.send(count), receiver.receive(&choices)).unwrap();
        for ((r, s), c) in recv_ots.into_iter().zip(send_ots).zip(choices) {
            assert_eq!(r, s[c.unwrap_u8() as usize]);
        }
    }

    /// Same check as [`check_semi_honest_rot`] but for the malicious
    /// (KOS15) instantiation of the extension protocol.
    async fn check_malicious_rot(count: usize) {
        let (c1, c2) = local_conn().await.unwrap();
        let rng1 = StdRng::seed_from_u64(42);
        let mut rng2 = StdRng::seed_from_u64(24);
        let choices = random_choices(count, &mut rng2);
        let mut sender = OtExtensionSender::<MaliciousMarker>::new_with_rng(c1, rng1);
        let mut receiver = OtExtensionReceiver::<MaliciousMarker>::new_with_rng(c2, rng2);
        let (send_ots, recv_ots) =
            tokio::try_join!(sender.send(count), receiver.receive(&choices)).unwrap();
        for ((r, s), c) in recv_ots.into_iter().zip(send_ots).zip(choices) {
            assert_eq!(r, s[c.unwrap_u8() as usize]);
        }
    }

    #[tokio::test]
    async fn test_extension() {
        let _g = init_tracing();
        check_semi_honest_rot(2 * DEFAULT_OT_BATCH_SIZE).await;
    }

    #[tokio::test]
    async fn test_extension_half_batch() {
        let _g = init_tracing();
        check_semi_honest_rot(2 * DEFAULT_OT_BATCH_SIZE + DEFAULT_OT_BATCH_SIZE / 2).await;
    }

    #[tokio::test]
    async fn test_extension_partial_batch() {
        let _g = init_tracing();
        check_semi_honest_rot(DEFAULT_OT_BATCH_SIZE / 2 + 128).await;
    }

    #[tokio::test]
    async fn test_extension_malicious_half_batch() {
        let _g = init_tracing();
        check_malicious_rot(DEFAULT_OT_BATCH_SIZE / 2).await;
    }

    #[tokio::test]
    async fn test_extension_malicious_partial_batch() {
        let _g = init_tracing();
        check_malicious_rot(DEFAULT_OT_BATCH_SIZE + DEFAULT_OT_BATCH_SIZE / 2 + 128).await;
    }

    #[tokio::test]
    async fn test_extension_malicious_multiple_batch() {
        let _g = init_tracing();
        check_malicious_rot(DEFAULT_OT_BATCH_SIZE * 2).await;
    }

    /// Correlated OT check: sender and receiver blocks must be equal, or
    /// differ exactly by the correlation when the choice bit is set.
    #[tokio::test]
    async fn test_correlated_extension() {
        let _g = init_tracing();
        const COUNT: usize = 128;
        let (c1, c2) = local_conn().await.unwrap();
        let rng1 = StdRng::seed_from_u64(42);
        let mut rng2 = StdRng::seed_from_u64(24);
        let choices = random_choices(COUNT, &mut rng2);
        let mut sender = SemiHonestOtExtensionSender::new_with_rng(c1, rng1);
        let mut receiver = SemiHonestOtExtensionReceiver::new_with_rng(c2, rng2);
        let (send_ots, recv_ots) = tokio::try_join!(
            sender.correlated_send(COUNT, |_| Block::ONES),
            receiver.correlated_receive(&choices)
        )
        .unwrap();
        for (i, ((r, s), c)) in recv_ots.into_iter().zip(send_ots).zip(choices).enumerate() {
            if bool::from(c) {
                assert_eq!(r ^ Block::ONES, s, "Block {i}");
            } else {
                assert_eq!(r, s, "Block {i}")
            }
        }
    }
}