use ck_meow::Meow;
use elliptic_curve::{CurveArithmetic};
use magikitten::MeowRng;
use rand_core::{OsRng, RngCore};
use subtle::{Choice, ConditionallySelectable, ConstantTimeEq};
use crate::{
compat::CSCurve,
constants::SECURITY_PARAMETER,
protocol::{
internal::{make_protocol, Context, PrivateChannel},
run_two_party_protocol, Participant, ProtocolError,
},
};
use super::{
bits::{BitMatrix, BitVector, ChoiceVector, DoubleBitVector, SquareBitMatrix},
correlated_ot_extension::{correlated_ot_receiver, correlated_ot_sender, CorrelatedOtParams},
};
/// Domain-separation context for the Meow transcript used by [`hash_to_scalar`].
const MEOW_CTX: &[u8] = b"Random OT Extension Hash";

/// Derive a scalar of curve `C` from the row index `i` and the bit vector `v`.
///
/// The index is absorbed as metadata (little-endian `u64`) and the bytes of `v` as
/// associated data. A 32-byte PRF output is then used to seed a `MeowRng`, from which the
/// scalar is sampled with `C::sample_scalar_constant_time`.
///
/// # Panics
/// Panics if `i` does not fit into a `u64`.
fn hash_to_scalar<C: CSCurve>(i: usize, v: &BitVector) -> C::Scalar {
    let index = u64::try_from(i).expect("failed to convert usize to u64");

    let mut transcript = Meow::new(MEOW_CTX);
    transcript.meta_ad(&index.to_le_bytes(), false);
    transcript.ad(&v.bytes(), false);

    let mut prf_output = [0u8; 32];
    transcript.prf(&mut prf_output, false);

    let mut rng = MeowRng::new(&prf_output);
    C::sample_scalar_constant_time(&mut rng)
}
/// Compute the internal batch size used for the extension.
///
/// `size` is rounded up to the next multiple of `SECURITY_PARAMETER` (values that are
/// already multiples are kept), and then `2 * SECURITY_PARAMETER` extra rows are added.
/// Callers only return the first `size` outputs. Presumably the extra rows exist to serve
/// the consistency check (TODO: confirm against the protocol specification).
fn adjust_size(size: usize) -> usize {
    let rounded_up = match size % SECURITY_PARAMETER {
        0 => size,
        remainder => size + (SECURITY_PARAMETER - remainder),
    };
    rounded_up + 2 * SECURITY_PARAMETER
}
/// Parameters shared by the sender and receiver of random OT extension.
#[derive(Debug, Clone, Copy)]
pub struct RandomOtExtensionParams<'sid> {
/// Session identifier; forwarded unchanged to the underlying correlated OT extension.
pub sid: &'sid [u8],
/// Number of random OTs returned to the caller. Internally the extension runs on a
/// larger, padded batch (see `adjust_size`), and only the first `batch_size` results are kept.
pub batch_size: usize,
}
/// Output of the sender: one `(v0, v1)` pair of scalars per random OT.
pub type RandomOTExtensionSenderOut<C> = Vec<(
<C as CurveArithmetic>::Scalar,
<C as CurveArithmetic>::Scalar,
)>;
/// Output of the receiver: one `(b, v_b)` pair per random OT. `b` is the receiver's random
/// choice bit and `v_b` is the scalar that should match the sender's `v0` (if `b` is 0) or `v1`
/// (if `b` is 1).
pub type RandomOTExtensionReceiverOut<C> = Vec<(Choice, <C as CurveArithmetic>::Scalar)>;
/// Sender side of random OT extension.
///
/// The protocol runs in four stages:
/// 1. Runs correlated OT extension (on `chan.child(0)`) for a batch padded by `adjust_size`,
///    obtaining the matrix `q`.
/// 2. Sends a fresh random 32-byte seed. Both parties expand it with `MeowRng` into the
///    challenge vectors `chi`.
/// 3. Receives `(small_x, small_t)` from the receiver and checks every column `j`. The
///    `chi`-weighted sum of column `j` of `q` must equal `small_t[j]`, XORed with `small_x`
///    when bit `j` of `delta` is set.
/// 4. For each of the first `params.batch_size` rows `q_i`, outputs
///    `(H(i, q_i), H(i, q_i ^ delta))` using `hash_to_scalar`.
///
/// # Errors
/// - Propagates errors from the correlated OT sender and from receiving the reply.
/// - Returns `ProtocolError::AssertionFailed` if `small_t` does not contain exactly
///   `SECURITY_PARAMETER` entries, or if the consistency check fails for any column.
pub async fn random_ot_extension_sender<C: CSCurve>(
    mut chan: PrivateChannel,
    params: RandomOtExtensionParams<'_>,
    delta: BitVector,
    k: &SquareBitMatrix,
) -> Result<RandomOTExtensionSenderOut<C>, ProtocolError> {
    let adjusted_size = adjust_size(params.batch_size);

    let q = correlated_ot_sender(
        chan.child(0),
        CorrelatedOtParams {
            sid: params.sid,
            batch_size: adjusted_size,
        },
        delta,
        k,
    )
    .await?;

    // The challenge seed is only revealed after the correlated OT messages have been exchanged.
    let mut seed = [0u8; 32];
    OsRng.fill_bytes(&mut seed);
    let wait0 = chan.next_waitpoint();
    chan.send(wait0, &seed).await;

    let mu = adjusted_size / SECURITY_PARAMETER;
    let mut prng = MeowRng::new(&seed);
    let chi: Vec<BitVector> = (0..mu).map(|_| BitVector::random(&mut prng)).collect();

    let wait1 = chan.next_waitpoint();
    let (small_x, small_t): (DoubleBitVector, Vec<DoubleBitVector>) = chan.recv(wait1).await?;

    if small_t.len() != SECURITY_PARAMETER {
        return Err(ProtocolError::AssertionFailed(
            "small t of incorrect length".to_owned(),
        ));
    }

    // Each column's comparison depends on bit `j` of `delta`. Fold all of them into a single
    // `Choice` and branch once at the end, so that the timing of an abort does not reveal
    // which column(s) failed. Only the overall pass/fail outcome is leaked.
    let mut check_passed = Choice::from(1u8);
    for (j, small_t_j) in small_t.iter().enumerate() {
        let delta_j = Choice::from(delta.bit(j) as u8);

        let mut small_q_j = DoubleBitVector::zero();
        for (q_i, chi_i) in q.column_chunks(j).zip(chi.iter()) {
            small_q_j ^= q_i.gf_mul(chi_i);
        }

        let delta_j_x =
            DoubleBitVector::conditional_select(&DoubleBitVector::zero(), &small_x, delta_j);
        check_passed &= small_q_j.ct_eq(&(small_t_j ^ delta_j_x));
    }
    if !bool::from(check_passed) {
        return Err(ProtocolError::AssertionFailed("q check failed".to_owned()));
    }

    // Only the first `batch_size` rows are returned; the padding rows are discarded.
    let out: RandomOTExtensionSenderOut<C> = q
        .rows()
        .take(params.batch_size)
        .enumerate()
        .map(|(i, q_i)| {
            let v0_i = hash_to_scalar::<C>(i, q_i);
            let v1_i = hash_to_scalar::<C>(i, &(q_i ^ delta));
            (v0_i, v1_i)
        })
        .collect();

    Ok(out)
}
pub async fn random_ot_extension_receiver<C: CSCurve>(
mut chan: PrivateChannel,
params: RandomOtExtensionParams<'_>,
k0: &SquareBitMatrix,
k1: &SquareBitMatrix,
) -> Result<RandomOTExtensionReceiverOut<C>, ProtocolError> {
let adjusted_size = adjust_size(params.batch_size);
let b = ChoiceVector::random(&mut OsRng, adjusted_size);
let x: BitMatrix = b
.bits()
.map(|b_i| BitVector::conditional_select(&BitVector::zero(), &!BitVector::zero(), b_i))
.collect();
let t = correlated_ot_receiver(
chan.child(0),
CorrelatedOtParams {
sid: params.sid,
batch_size: adjusted_size,
},
k0,
k1,
&x,
)
.await;
let wait0 = chan.next_waitpoint();
let seed: [u8; 32] = chan.recv(wait0).await?;
let mu = adjusted_size / SECURITY_PARAMETER;
let mut prng = MeowRng::new(&seed);
let chi: Vec<BitVector> = (0..mu).map(|_| BitVector::random(&mut prng)).collect();
let mut small_x = DoubleBitVector::zero();
for (b_i, chi_i) in b.chunks().zip(chi.iter()) {
small_x.xor_mut(&b_i.gf_mul(chi_i));
}
let small_t: Vec<_> = (0..SECURITY_PARAMETER)
.map(|j| {
let mut small_t_j = DoubleBitVector::zero();
for (t_i, chi_i) in t.column_chunks(j).zip(chi.iter()) {
small_t_j ^= t_i.gf_mul(chi_i);
}
small_t_j
})
.collect();
let wait1 = chan.next_waitpoint();
chan.send(wait1, &(small_x, small_t)).await;
let out: Vec<_> = b
.bits()
.zip(t.rows())
.take(params.batch_size)
.enumerate()
.map(|(i, (b_i, t_i))| (b_i, hash_to_scalar::<C>(i, t_i)))
.collect();
Ok(out)
}
/// Run both sides of random OT extension locally, with participant 0 as the sender and
/// participant 1 as the receiver.
///
/// Returns the sender's and the receiver's outputs, or the first protocol error.
// Used only by the tests in this file. `allow(dead_code)` (rather than `cfg(test)`) keeps
// non-test builds free of warnings without touching the imports this helper relies on.
#[allow(dead_code)]
fn run_random_ot<C: CSCurve>(
    (delta, k): (BitVector, &SquareBitMatrix),
    (k0, k1): (&SquareBitMatrix, &SquareBitMatrix),
    sid: &[u8],
    batch_size: usize,
) -> Result<
    (
        RandomOTExtensionSenderOut<C>,
        RandomOTExtensionReceiverOut<C>,
    ),
    ProtocolError,
> {
    let sender = Participant::from(0u32);
    let receiver = Participant::from(1u32);
    let params = RandomOtExtensionParams { sid, batch_size };

    let sender_ctx = Context::new();
    let receiver_ctx = Context::new();

    let mut sender_protocol = make_protocol(
        sender_ctx.clone(),
        random_ot_extension_sender::<C>(
            sender_ctx.private_channel(sender, receiver),
            params,
            delta,
            k,
        ),
    );
    let mut receiver_protocol = make_protocol(
        receiver_ctx.clone(),
        random_ot_extension_receiver::<C>(
            receiver_ctx.private_channel(receiver, sender),
            params,
            k0,
            k1,
        ),
    );

    run_two_party_protocol(sender, receiver, &mut sender_protocol, &mut receiver_protocol)
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::triples::batch_random_ot::run_batch_random_ot;
    use k256::{Scalar, Secp256k1};

    /// Runs the base OTs, extends them, and checks the OT correctness relation: every
    /// receiver output equals the sender value selected by the receiver's choice bit.
    #[test]
    fn test_random_ot() -> Result<(), ProtocolError> {
        const BATCH_SIZE: usize = 16;

        let ((k0, k1), (delta, k)) = run_batch_random_ot::<Secp256k1>()?;
        let (sender_out, receiver_out) =
            run_random_ot::<Secp256k1>((delta, &k), (&k0, &k1), b"test sid", BATCH_SIZE)?;

        assert_eq!(sender_out.len(), BATCH_SIZE);
        assert_eq!(receiver_out.len(), BATCH_SIZE);

        sender_out
            .iter()
            .zip(&receiver_out)
            .for_each(|((v0, v1), (choice, chosen))| {
                assert_eq!(*chosen, Scalar::conditional_select(v0, v1, *choice));
            });

        Ok(())
    }
}