use auto_ops::impl_op_ex;
use ck_meow::Meow;
use rand_core::CryptoRngCore;
use serde::{Deserialize, Serialize};
use subtle::{Choice, ConditionallySelectable, ConstantTimeEq};
use crate::constants::SECURITY_PARAMETER;
/// Number of `u64` limbs needed to hold `SECURITY_PARAMETER` bits (rounded up).
pub const SEC_PARAM_64: usize = (SECURITY_PARAMETER + 63) / 64;
/// Number of bytes needed to hold `SECURITY_PARAMETER` bits (rounded up).
pub const SEC_PARAM_8: usize = (SECURITY_PARAMETER + 7) / 8;

/// A fixed-width bit vector made of `SEC_PARAM_64` `u64` limbs.
///
/// Bit `j` lives in limb `j / 64` at position `j % 64`, so limb 0 holds bits
/// `0..64`, limb 1 holds bits `64..128`, and so on.
///
/// NOTE(review): the derived `PartialEq`/`Eq` is an ordinary (not constant-time)
/// comparison; prefer `subtle`-based comparisons on secret data.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct BitVector([u64; SEC_PARAM_64]);
impl BitVector {
    /// Returns the vector with every bit cleared.
    pub fn zero() -> Self {
        Self([0; SEC_PARAM_64])
    }

    /// Samples a uniformly random vector, filling limbs in order with
    /// successive `next_u64` draws from `rng`.
    pub fn random(rng: &mut impl CryptoRngCore) -> Self {
        let mut limbs = [0u64; SEC_PARAM_64];
        limbs.iter_mut().for_each(|limb| *limb = rng.next_u64());
        Self(limbs)
    }

    /// Returns bit `j` as `0` or `1`.
    ///
    /// # Panics
    /// Panics if `j / 64` is not a valid limb index.
    #[inline(always)]
    pub fn bit(&self, j: usize) -> u64 {
        let (limb, offset) = (j / 64, j % 64);
        (self.0[limb] >> offset) & 1
    }

    /// Decodes a vector from bytes, reading each limb as 8 little-endian bytes.
    pub fn from_bytes(bytes: &[u8; SEC_PARAM_8]) -> Self {
        let mut limbs = [0u64; SEC_PARAM_64];
        for (limb, chunk) in limbs.iter_mut().zip(bytes.chunks_exact(8)) {
            // `chunks_exact(8)` only yields 8-byte slices, so this cannot fail.
            *limb = u64::from_le_bytes(chunk.try_into().unwrap());
        }
        Self(limbs)
    }

    /// Encodes the vector as bytes, writing each limb as 8 little-endian bytes.
    pub fn bytes(&self) -> [u8; SEC_PARAM_8] {
        let mut encoded = [0u8; SEC_PARAM_8];
        for (i, limb) in self.0.iter().enumerate() {
            let start = 8 * i;
            encoded[start..start + 8].copy_from_slice(&limb.to_le_bytes());
        }
        encoded
    }

    /// Yields every bit (limb 0 first, least-significant bit first) as a [`Choice`].
    pub fn bits(&self) -> impl Iterator<Item = Choice> {
        let limbs = self.0;
        limbs
            .into_iter()
            .flat_map(|limb| (0..64).map(move |shift| ((limb >> shift) & 1).ct_eq(&1)))
    }

    /// XORs `other` into `self`, limb by limb.
    pub fn xor_mut(&mut self, other: &Self) {
        self.0
            .iter_mut()
            .zip(&other.0)
            .for_each(|(mine, theirs)| *mine ^= *theirs);
    }

    /// Returns `self ^ other`.
    pub fn xor(&self, other: &Self) -> Self {
        let mut result = *self;
        result.xor_mut(other);
        result
    }

    /// Returns the bitwise complement of `self`.
    pub fn not(&self) -> Self {
        Self(self.0.map(|limb| !limb))
    }

    /// ANDs `other` into `self`, limb by limb.
    pub fn and_mut(&mut self, other: &Self) {
        self.0
            .iter_mut()
            .zip(&other.0)
            .for_each(|(mine, theirs)| *mine &= *theirs);
    }

    /// Returns `self & other`.
    pub fn and(&self, other: &Self) -> Self {
        let mut result = *self;
        result.and_mut(other);
        result
    }

    /// Carry-less (GF(2)[x]) product of `self` and `other`, without reduction.
    ///
    /// Bit `i` of a vector is treated as the coefficient of `x^i`; the full
    /// product is returned in a [`DoubleBitVector`] of twice the width.
    /// Whether `other` is added for a given bit of `self` is decided with
    /// `conditional_select` rather than a branch on that bit.
    pub fn gf_mul(&self, other: &Self) -> DoubleBitVector {
        let mut acc = [0u64; 2 * SEC_PARAM_64];
        // Horner-style: handle bit position `k` of every limb, then shift the
        // accumulator left by one, so a bit added at position `k` ends up
        // shifted by exactly `k` more places.
        for k in (0..64).rev() {
            for (j, &limb) in self.0.iter().enumerate() {
                let selected = Self::conditional_select(
                    &Self::zero(),
                    other,
                    Choice::from(((limb >> k) & 1) as u8),
                );
                // Adding at limb offset `j` accounts for the `64 * j` part of the degree.
                for (dst, &src) in acc[j..].iter_mut().zip(selected.0.iter()) {
                    *dst ^= src;
                }
            }
            if k > 0 {
                // Multi-limb left shift by one bit, carrying across limb boundaries.
                let mut carry = 0u64;
                for word in acc.iter_mut() {
                    let outgoing = *word >> 63;
                    *word = (*word << 1) | carry;
                    carry = outgoing;
                }
            }
        }
        DoubleBitVector(acc)
    }
}
impl ConditionallySelectable for BitVector {
    /// Limb-wise constant-time selection: `a` when `choice` is 0, `b` when it is 1.
    fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
        let mut limbs = [0u64; SEC_PARAM_64];
        for (idx, limb) in limbs.iter_mut().enumerate() {
            *limb = u64::conditional_select(&a.0[idx], &b.0[idx], choice);
        }
        Self(limbs)
    }
}
// Operator sugar for `BitVector`: each operator delegates to the matching
// inherent `BitVector` method.
impl_op_ex!(^ |lhs: &BitVector, rhs: &BitVector| -> BitVector { lhs.xor(rhs) });
impl_op_ex!(^= |lhs: &mut BitVector, rhs: &BitVector| { lhs.xor_mut(rhs) });
impl_op_ex!(& |lhs: &BitVector, rhs: &BitVector| -> BitVector { lhs.and(rhs) });
impl_op_ex!(&= |lhs: &mut BitVector, rhs: &BitVector| { lhs.and_mut(rhs) });
impl_op_ex!(! |operand: &BitVector| -> BitVector { operand.not() });
/// A bit vector with twice as many `u64` limbs as a [`BitVector`], wide
/// enough to hold the unreduced product returned by [`BitVector::gf_mul`].
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
#[cfg_attr(test, derive(PartialEq, Eq))]
pub struct DoubleBitVector([u64; Self::SIZE]);

impl DoubleBitVector {
    /// Number of `u64` limbs in a [`DoubleBitVector`].
    const SIZE: usize = 2 * SEC_PARAM_64;

    /// Returns the vector with every bit cleared.
    pub fn zero() -> Self {
        Self([0; Self::SIZE])
    }

    /// XORs `other` into `self`, limb by limb.
    pub fn xor_mut(&mut self, other: &Self) {
        self.0
            .iter_mut()
            .zip(other.0.iter())
            .for_each(|(dst, src)| *dst ^= *src);
    }

    /// Returns `self ^ other`.
    pub fn xor(&self, other: &Self) -> Self {
        let mut result = *self;
        result.xor_mut(other);
        result
    }
}
impl ConditionallySelectable for DoubleBitVector {
    /// Limb-wise constant-time selection: `a` when `choice` is 0, `b` when it is 1.
    fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
        let mut limbs = [0u64; Self::SIZE];
        for (idx, limb) in limbs.iter_mut().enumerate() {
            *limb = u64::conditional_select(&a.0[idx], &b.0[idx], choice);
        }
        Self(limbs)
    }
}
impl ConstantTimeEq for DoubleBitVector {
    /// Constant-time equality: ANDs the per-limb `ct_eq` results together,
    /// visiting every limb regardless of earlier mismatches.
    fn ct_eq(&self, other: &Self) -> Choice {
        self.0
            .iter()
            .zip(other.0.iter())
            .fold(Choice::from(1), |all_equal, (x, y)| all_equal & x.ct_eq(y))
    }
}
// Operator sugar for `DoubleBitVector`, delegating to its inherent XOR methods.
impl_op_ex!(^ |lhs: &DoubleBitVector, rhs: &DoubleBitVector| -> DoubleBitVector { lhs.xor(rhs) });
impl_op_ex!(^= |lhs: &mut DoubleBitVector, rhs: &DoubleBitVector| { lhs.xor_mut(rhs) });
/// Domain-separation context passed to `Meow::new` by
/// `SquareBitMatrix::expand_transpose`.
///
/// NOTE: this string is mixed into the PRG transcript, so changing it is
/// expected to change every expanded output. Treat it as part of the protocol.
const PRG_CTX: &[u8] = b"cait-sith v0.8.0 correlated OT PRG";
/// A bit matrix stored as a list of [`BitVector`] rows.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(test, derive(PartialEq, Eq))]
pub struct BitMatrix(Vec<BitVector>);

impl BitMatrix {
    /// Samples a matrix of `height` independently random rows.
    ///
    /// # Panics
    /// Panics if `height` is not a multiple of `SECURITY_PARAMETER`.
    pub fn random(rng: &mut impl CryptoRngCore, height: usize) -> Self {
        assert!(height % SECURITY_PARAMETER == 0);
        Self((0..height).map(|_| BitVector::random(rng)).collect())
    }

    /// Builds a matrix by copying the given rows in order.
    ///
    /// Unlike [`BitMatrix::random`], no constraint is placed on the height.
    pub fn from_rows<'a>(rows: impl IntoIterator<Item = &'a BitVector>) -> Self {
        Self(rows.into_iter().copied().collect())
    }

    /// Number of rows in the matrix.
    pub fn height(&self) -> usize {
        self.0.len()
    }

    /// Iterates over the rows in order.
    pub fn rows(&self) -> impl Iterator<Item = &BitVector> {
        self.0.iter()
    }

    /// For each consecutive block of `SECURITY_PARAMETER` rows, gathers bit `j`
    /// of every row in the block into one [`BitVector`]. Row `i` of the block
    /// becomes bit `i` of the yielded vector.
    ///
    /// Trailing rows that do not fill a whole block are skipped (`chunks_exact`).
    ///
    /// # Panics
    /// Iteration panics if `j` is out of range for a row (see [`BitVector::bit`]).
    pub fn column_chunks(&self, j: usize) -> impl Iterator<Item = BitVector> + '_ {
        self.0.chunks_exact(SECURITY_PARAMETER).map(move |chunk| {
            let mut out = BitVector::zero();
            for (i, c_i) in chunk.iter().enumerate() {
                out.0[i / 64] |= c_i.bit(j) << (i % 64);
            }
            out
        })
    }

    /// XORs `other` into `self` row by row.
    ///
    /// Both matrices must have the same height. `zip` would otherwise stop at
    /// the shorter matrix and leave the extra rows of `self` un-XORed, so a
    /// mismatch is reported in debug builds instead of passing silently.
    pub fn xor_mut(&mut self, other: &Self) {
        debug_assert_eq!(
            self.height(),
            other.height(),
            "BitMatrix::xor_mut: height mismatch"
        );
        for (self_i, other_i) in self.0.iter_mut().zip(other.0.iter()) {
            *self_i ^= other_i;
        }
    }

    /// Returns `self ^ other`; see [`BitMatrix::xor_mut`] for the height requirement.
    pub fn xor(&self, other: &Self) -> Self {
        let mut out = self.clone();
        out.xor_mut(other);
        out
    }

    /// ANDs every row with the mask `v`, in place.
    pub fn and_vec_mut(&mut self, v: &BitVector) {
        for self_i in &mut self.0 {
            *self_i &= v;
        }
    }

    /// Returns a copy of `self` with every row ANDed with the mask `v`.
    pub fn and_vec(&self, v: &BitVector) -> Self {
        let mut out = self.clone();
        out.and_vec_mut(v);
        out
    }
}
impl FromIterator<BitVector> for BitMatrix {
    /// Collects rows, in iteration order, into a matrix (no height check).
    fn from_iter<T: IntoIterator<Item = BitVector>>(iter: T) -> Self {
        let rows: Vec<BitVector> = iter.into_iter().collect();
        Self(rows)
    }
}
// Operator sugar for `BitMatrix`: matrix XOR, plus `matrix & mask` which ANDs
// every row with a single `BitVector`.
impl_op_ex!(^ |lhs: &BitMatrix, rhs: &BitMatrix| -> BitMatrix { lhs.xor(rhs) });
impl_op_ex!(^= |lhs: &mut BitMatrix, rhs: &BitMatrix| { lhs.xor_mut(rhs) });
impl_op_ex!(& |matrix: &BitMatrix, mask: &BitVector| -> BitMatrix { matrix.and_vec(mask) });
/// A [`BitMatrix`] with exactly `SECURITY_PARAMETER` rows.
///
/// NOTE(review): `matrix` is a public field, so the row-count invariant is
/// only enforced when constructing through `TryFrom<BitMatrix>`.
#[derive(Debug, Clone)]
#[cfg_attr(test, derive(PartialEq, Eq))]
pub struct SquareBitMatrix {
    pub matrix: BitMatrix,
}

impl TryFrom<BitMatrix> for SquareBitMatrix {
    type Error = ();

    /// Wraps `matrix` if it has exactly `SECURITY_PARAMETER` rows; otherwise
    /// returns `Err(())`.
    fn try_from(matrix: BitMatrix) -> Result<Self, Self::Error> {
        if matrix.height() == SECURITY_PARAMETER {
            Ok(Self { matrix })
        } else {
            Err(())
        }
    }
}
impl SquareBitMatrix {
    /// Expands each row of this matrix into `rows` pseudorandom bits and
    /// returns the transpose of the result.
    ///
    /// Row `j` of `self` is absorbed into a Meow transcript bound to `PRG_CTX`
    /// and `sid`, then `prf` squeezes `rows` bits out of it. Bit `i` of that
    /// expansion is written to bit `j` of output row `i`, so the returned
    /// matrix has `rows` rows. The exact sequence of transcript operations
    /// determines the output; do not reorder it.
    ///
    /// # Panics
    /// Panics if `rows` is not a multiple of `SECURITY_PARAMETER`.
    ///
    /// NOTE(review): `matrix` is a public field, so its height is not enforced
    /// here. A matrix with more than `64 * SEC_PARAM_64` rows would index out
    /// of bounds when writing column `j`.
    pub fn expand_transpose(&self, sid: &[u8], rows: usize) -> BitMatrix {
        assert!(rows % SECURITY_PARAMETER == 0);
        // Shared transcript prefix: PRG context, then the session id.
        let mut meow = Meow::new(PRG_CTX);
        meow.meta_ad(b"sid", false);
        meow.ad(sid, false);
        let mut out = BitMatrix(vec![BitVector::zero(); rows]);
        // Bytes needed to hold `rows` bits, rounded up.
        let row8 = (rows + 7) / 8;
        for (j, row) in self.matrix.0.iter().enumerate() {
            let mut expanded = vec![0u8; row8];
            // Fork the sid-bound transcript so every row starts from the same prefix.
            let mut meow = meow.clone();
            meow.meta_ad(b"row", false);
            // Start an empty `ad` operation, then stream each limb of the row with
            // `more = true` (presumably continuing that same operation, so the row is
            // absorbed as one message — confirm against the ck_meow docs).
            meow.ad(b"", false);
            for u in row.0 {
                meow.ad(&u.to_le_bytes(), true);
            }
            meow.prf(&mut expanded, false);
            // Scatter the expansion into column `j`: bit `i` of `expanded`
            // (byte `i / 8`, bit `i % 8`) becomes bit `j` of output row `i`.
            for i in 0..rows {
                out.0[i].0[j / 64] |= u64::from((expanded[i / 8] >> (i % 8)) & 1) << (j % 64);
            }
        }
        out
    }
}
/// A vector of choice bits stored as consecutive [`BitVector`] chunks.
#[derive(Debug, Clone)]
pub struct ChoiceVector(Vec<BitVector>);

impl ChoiceVector {
    /// Samples `size` random choice bits, one [`BitVector`] per
    /// `SECURITY_PARAMETER` bits.
    ///
    /// # Panics
    /// Panics unless `size` is a positive multiple of `SECURITY_PARAMETER`.
    pub fn random(rng: &mut impl CryptoRngCore, size: usize) -> Self {
        assert!(size > 0 && size % SECURITY_PARAMETER == 0);
        let chunk_count = size / SECURITY_PARAMETER;
        let mut chunks = Vec::with_capacity(chunk_count);
        for _ in 0..chunk_count {
            chunks.push(BitVector::random(rng));
        }
        Self(chunks)
    }

    /// Yields every bit of every chunk, in order, as a `Choice`.
    pub fn bits(&self) -> impl Iterator<Item = Choice> + '_ {
        self.0.iter().flat_map(BitVector::bits)
    }

    /// Iterates over the underlying chunks in order.
    pub fn chunks(&self) -> impl Iterator<Item = &BitVector> {
        self.0.as_slice().iter()
    }
}
#[cfg(test)]
mod test {
    use super::*;

    /// Over GF(2): (x + x^65) * (x^2 + x^66) = x^3 + 2*x^67 + x^131 = x^3 + x^131.
    /// That is bit 3 of limb 0 and bit 3 of limb 2 of the double-width product.
    /// The two-limb literals assume `SEC_PARAM_64 == 2`.
    #[test]
    fn test_gf_multiplication() {
        let lhs = BitVector([1 << 1, 1 << 1]);
        let rhs = BitVector([1 << 2, 1 << 2]);
        let expected = DoubleBitVector([1 << 3, 0, 1 << 3, 0]);
        assert_eq!(lhs.gf_mul(&rhs), expected);
    }
}