mod matrix;
mod regev;
use matrix::{a_matrix_mul_db, mat_vec_mul, packed_mat_vec_mul};
pub use matrix::{Matrix, Vector};
use rand::prelude::*;
use rand_chacha::ChaCha20Rng;
use rand_distr::Uniform;
use regev::{encrypt, gen_a_matrix, gen_secret_key};
use thiserror::Error;
/// A matrix of plaintext records together with the modulus its entries are reduced by.
///
/// Every constructor in this module builds a square `side_len x side_len` matrix and sets
/// `modulus` to `2^mod_power` with `mod_power` clamped to at most 32.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd)]
pub struct Database {
/// The records; `Database::get` addresses them by a flat row-major index.
pub data: Matrix,
/// Plaintext modulus (a power of two when built through this module's constructors).
pub modulus: u64,
}
#[derive(Error, Debug)]
pub enum DatabaseError {
#[error("The number of rows and columns in the matrix must be equal")]
NonSquareMatrixError,
#[error("The modulus of the database must be less than 21")]
CompressionModulusError,
}
impl Database {
    /// Returns the plaintext modulus `2^mod_power`, with `mod_power` clamped to at most 32.
    fn modulus_from_power(mod_power: u8) -> u64 {
        2_u64.pow(u32::from(mod_power.min(32)))
    }

    /// Returns the smallest `side` such that `side * side >= n`.
    ///
    /// The floating-point square root is only a first guess: once `n` exceeds the float's
    /// mantissa precision the guess can be off by one, so it is corrected with exact integer
    /// comparisons.
    fn ceil_sqrt(n: usize) -> usize {
        let mut side = (n as f64).sqrt() as usize;
        while side.saturating_mul(side) < n {
            side += 1;
        }
        while side > 0 && (side - 1).saturating_mul(side - 1) >= n {
            side -= 1;
        }
        side
    }

    /// Shared body of [`Database::new_random`] and [`Database::new_random_seed`]: a square
    /// matrix of entries drawn by `Matrix::new_random` from `0..=modulus - 1`.
    fn random_square(side_len: usize, mod_power: u8, seed: Option<u64>) -> Database {
        let modulus = Self::modulus_from_power(mod_power);
        Database {
            data: Matrix::new_random(side_len, side_len, 0..=modulus - 1, seed),
            modulus,
        }
    }

    /// Wraps an existing matrix as a database with modulus `2^min(mod_power, 32)`.
    ///
    /// # Errors
    ///
    /// Returns [`DatabaseError::NonSquareMatrixError`] if `data` is not square.
    pub fn from_matrix(data: Matrix, mod_power: u8) -> Result<Database, DatabaseError> {
        if data.nrows() != data.ncols() {
            return Err(DatabaseError::NonSquareMatrixError);
        }
        Ok(Database {
            data,
            modulus: Self::modulus_from_power(mod_power),
        })
    }

    /// Creates a `side_len x side_len` database of random entries in `0..modulus`, using an
    /// unseeded `Matrix::new_random`.
    pub fn new_random(side_len: usize, mod_power: u8) -> Database {
        Self::random_square(side_len, mod_power, None)
    }

    /// Like [`Database::new_random`], but passes `seed` to `Matrix::new_random` so the
    /// contents are reproducible.
    pub fn new_random_seed(side_len: usize, mod_power: u8, seed: u64) -> Database {
        Self::random_square(side_len, mod_power, Some(seed))
    }

    /// Lays a flat list of records out row-major in the smallest square matrix that holds them.
    ///
    /// When `data.len()` is not a perfect square, the trailing cells are filled with zeros so
    /// the matrix always receives exactly `side * side` entries.
    pub fn from_vector(mut data: Vec<u64>, mod_power: u8) -> Database {
        let side = Self::ceil_sqrt(data.len());
        data.resize(side * side, 0);
        Database {
            data: Matrix::from_vec(data, side, side),
            modulus: Self::modulus_from_power(mod_power),
        }
    }

    /// Creates an all-zero `side_len x side_len` database; `mod_power` defaults to 1.
    pub fn zeros(side_len: usize, mod_power: Option<u8>) -> Database {
        Database {
            data: Matrix::zeros(side_len, side_len),
            modulus: Self::modulus_from_power(mod_power.unwrap_or(1)),
        }
    }

    /// Number of rows of the underlying matrix.
    pub fn side_len(&self) -> usize {
        self.data.nrows()
    }

    /// Returns the record at flat row-major `index`.
    ///
    /// Returns `None` for an empty database, and otherwise whatever `Matrix::get` returns for
    /// the computed cell (presumably `None` when `index` is past the end — confirm in `matrix`).
    pub fn get(&self, index: usize) -> Option<u64> {
        let ncols = self.data.ncols();
        // `checked_div` maps a zero-width matrix to `None` instead of a divide-by-zero panic.
        let row = index.checked_div(ncols)?;
        self.data.get(row, index % ncols)
    }

    /// Packs three entries into each `u64` word, using `b = log2(modulus)` bits per entry.
    ///
    /// A row of `n` entries becomes `ceil(n / 3)` words; word `i` holds entries `3i`, `3i + 1`
    /// and `3i + 2` at bit offsets `0`, `b` and `2b`, with missing trailing entries taken as 0.
    ///
    /// # Errors
    ///
    /// Returns [`DatabaseError::CompressionModulusError`] when the modulus exceeds `2^21`,
    /// because three wider entries would not fit in 64 bits.
    pub fn compress(&self) -> Result<CompressedDatabase, DatabaseError> {
        if self.modulus > 2_u64.pow(21) {
            return Err(DatabaseError::CompressionModulusError);
        }
        let mod_power = self.modulus.ilog2();
        let mask = self.modulus - 1;
        let nrows = self.data.nrows();
        let packed_ncols = self.data.ncols().div_ceil(3);
        let packed: Vec<u64> = self
            .data
            .data
            .iter()
            .flat_map(move |row| {
                // Positions past the end of the row (a ragged final word) contribute 0.
                let entry = move |j: usize| row.get(j).copied().unwrap_or(0) & mask;
                (0..row.len().div_ceil(3)).map(move |i| {
                    entry(3 * i)
                        | (entry(3 * i + 1) << mod_power)
                        | (entry(3 * i + 2) << (2 * mod_power))
                })
            })
            .collect();
        Ok(CompressedDatabase {
            data: Matrix::from_vec(packed, nrows, packed_ncols),
            nrows,
            ncols: packed_ncols,
            mod_power,
        })
    }
}
/// A database whose entries are bit-packed three per `u64` word, as built by
/// `Database::compress`.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd)]
pub struct CompressedDatabase {
/// Packed words: `nrows` rows of `ncols` words each.
data: Matrix,
/// Number of rows (same as the uncompressed database).
nrows: usize,
/// Packed words per row, i.e. the uncompressed column count divided by 3, rounded up.
ncols: usize,
/// Bits per packed entry, `log2` of the uncompressed database's modulus.
mod_power: u32,
}
/// Per-query client state, produced by `query` and consumed by `recover`.
#[derive(Clone, Debug, Eq, PartialEq, PartialOrd)]
pub struct ClientState {
/// `index % db_side_len`: the position set to 1 in the plaintext query vector.
row_index: usize,
/// `index / db_side_len`: the answer entry and client-hint row that `recover` reads.
column_index: usize,
/// Seed that `query` passed to `gen_a_matrix`.
a_matrix_seed: u64,
/// Side length of the square database the query targets.
db_side_len: usize,
/// Secret key the query vector was encrypted under.
secret_key: Vector,
/// Length of `secret_key`.
secret_dimension: usize,
}
impl ClientState {
fn new(
row_index: usize,
column_index: usize,
a_matrix_seed: u64,
db_side_len: usize,
secret_key: &Vector,
) -> ClientState {
ClientState {
row_index,
column_index,
a_matrix_seed,
db_side_len,
secret_key: secret_key.clone(),
secret_dimension: secret_key.len(),
}
}
}
/// Server-side preprocessing with a freshly sampled seed.
///
/// Draws a random `u64` seed (the returned `server_hint`), regenerates the matrix `A` from it
/// with `gen_a_matrix`, and returns `a_matrix_mul_db(A, D')` as the client hint, where `D'` is
/// the database with `modulus / 2` subtracted from every entry modulo `2^64`.
///
/// Use [`setup_seeded`] for a reproducible variant.
pub fn setup(database: &Database, secret_dimension: usize) -> (u64, Matrix) {
    let mut rng = ChaCha20Rng::from_entropy();
    let server_hint = Uniform::from(0..=u64::MAX).sample(&mut rng);
    // Adding `2^64 - p/2` subtracts p/2 modulo 2^64 (assuming `add_scalar` wraps — confirm in
    // `matrix`). `wrapping_neg` computes that offset without tripping debug overflow checks,
    // unlike the equivalent `u64::MAX * (p / 2)`.
    let centered = database
        .data
        .add_scalar((database.modulus / 2).wrapping_neg());
    let a_matrix = gen_a_matrix(database.side_len(), secret_dimension, Some(server_hint));
    let client_hint = a_matrix_mul_db(&a_matrix, &centered);
    (server_hint, client_hint)
}
/// Deterministic variant of [`setup`]: the `server_hint` seed is drawn from a ChaCha20 RNG
/// seeded with `seed`, so identical inputs give identical `(server_hint, client_hint)`.
///
/// The client hint is `a_matrix_mul_db(A, D')`, where `A` comes from `gen_a_matrix` with the
/// sampled seed and `D'` is the database with `modulus / 2` subtracted from every entry modulo
/// `2^64`.
pub fn setup_seeded(database: &Database, secret_dimension: usize, seed: [u8; 32]) -> (u64, Matrix) {
    let mut rng = ChaCha20Rng::from_seed(seed);
    let server_hint = Uniform::from(0..=u64::MAX).sample(&mut rng);
    // Adding `2^64 - p/2` subtracts p/2 modulo 2^64 (assuming `add_scalar` wraps — confirm in
    // `matrix`). `wrapping_neg` computes that offset without tripping debug overflow checks,
    // unlike the equivalent `u64::MAX * (p / 2)`.
    let centered = database
        .data
        .add_scalar((database.modulus / 2).wrapping_neg());
    let a_matrix = gen_a_matrix(database.side_len(), secret_dimension, Some(server_hint));
    let client_hint = a_matrix_mul_db(&a_matrix, &centered);
    (server_hint, client_hint)
}
/// Builds the client's encrypted query for the record at flat `index`.
///
/// A fresh secret key is generated, `A` is regenerated from `a_matrix_seed`, and a one-hot
/// vector with a 1 at `index % db_side_len` is encrypted with `regev::encrypt`. The returned
/// [`ClientState`] keeps the key and `index / db_side_len`, which `recover` needs to decode
/// the answer.
pub fn query(
    index: usize,
    db_side_len: usize,
    secret_dimension: usize,
    a_matrix_seed: u64,
    plain_mod: u64,
) -> (ClientState, Vector) {
    let secret_key = gen_secret_key(secret_dimension, None);
    let a_matrix = gen_a_matrix(db_side_len, secret_dimension, Some(a_matrix_seed));

    // Split the flat index: the remainder picks the one-hot slot, the quotient picks which
    // entry of the answer (and row of the hint) holds the record.
    let row_index = index % db_side_len;
    let column_index = index / db_side_len;

    let mut selector = Vector::zeros(db_side_len);
    selector.data[row_index] = 1;

    let state = ClientState::new(
        row_index,
        column_index,
        a_matrix_seed,
        db_side_len,
        &secret_key,
    );
    let query_cipher = encrypt(&secret_key, &a_matrix, &selector, plain_mod).1;
    (state, query_cipher)
}
/// Server-side answer over a bit-packed database.
///
/// Multiplies the query ciphertext against the packed matrix with `packed_mat_vec_mul`,
/// passing the per-entry bit width so the words can be unpacked.
pub fn answer(database: &CompressedDatabase, query_cipher: &Vector) -> Vector {
    packed_mat_vec_mul(query_cipher, &database.data, database.mod_power)
}
/// Server-side answer over an uncompressed database: `mat_vec_mul` of the query ciphertext
/// with the raw record matrix.
pub fn answer_uncompressed(database: &Database, query_cipher: &Vector) -> Vector {
    mat_vec_mul(query_cipher, &database.data)
}
/// Decodes the record selected by `client_state` from the server's answer.
///
/// With `Δ = 2^64 / plaintext_mod`, the quantity
/// `answer[c] - (p/2)·Σquery - hint[c]·s` (all modulo `2^64`, `c` = `column_index`) should be
/// `Δ·(record - p/2)` plus small noise (assumes the usual Regev encoding in `regev::encrypt` —
/// TODO confirm). It is rounded to the nearest multiple of `Δ` and the `p/2` offset is removed
/// modulo `p`.
///
/// The ciphertext modulus is `2^64`, so all ciphertext arithmetic uses explicit wrapping ops,
/// which keeps this correct when overflow checks are enabled. The final reduction is only
/// exact when `plaintext_mod` divides `2^64`, i.e. is a power of two.
pub fn recover(
    client_state: &ClientState,
    client_hint: &Matrix,
    answer_cipher: &Vector,
    query_cipher: &Vector,
    plaintext_mod: u64,
) -> u64 {
    // Modulo 1 every value is 0; also avoids `Δ = 2^64` truncating to 0 below.
    if plaintext_mod == 1 {
        return 0;
    }
    let q_over_p = ((1u128 << 64) / u128::from(plaintext_mod)) as u64;
    let secret_key = &client_state.secret_key;
    let column_index = client_state.column_index;
    let half_p = plaintext_mod / 2;
    let hint_row = Vector::from_vec(client_hint.row_unchecked(column_index));
    let noised = answer_cipher
        .get_unchecked(column_index)
        .wrapping_sub(half_p.wrapping_mul(query_cipher.sum()))
        .wrapping_sub(hint_row.dot(secret_key));
    // Round to the nearest multiple of Δ; wrap-around near 2^64 correctly lands on 0.
    let denoised = noised.wrapping_add(q_over_p / 2) / q_over_p;
    denoised.wrapping_sub(half_p) % plaintext_mod
}
#[cfg(test)]
mod tests {
    use super::*;

    /// ChaCha20 seed shared by every test so the server setup is reproducible.
    const SEED: [u8; 32] = [
        1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24,
        25, 26, 27, 28, 29, 30, 31, 32,
    ];

    /// Runs `setup_seeded`, then a full query/answer/recover round trip for every index in
    /// `indices`, asserting each recovered record equals `database.get(index)`.
    ///
    /// Answers come from [`answer`] over `compressed` when it is `Some`, otherwise from
    /// [`answer_uncompressed`] over `database`.
    fn assert_round_trips(
        database: &Database,
        compressed: Option<&CompressedDatabase>,
        secret_dimension: usize,
        plain_mod: u64,
        indices: impl IntoIterator<Item = usize>,
    ) {
        let db_side_len = database.side_len();
        let (server_hint, client_hint) = setup_seeded(database, secret_dimension, SEED);
        for index in indices {
            let (client_state, query_cipher) =
                query(index, db_side_len, secret_dimension, server_hint, plain_mod);
            let answer_cipher = match compressed {
                Some(packed) => answer(packed, &query_cipher),
                None => answer_uncompressed(database, &query_cipher),
            };
            let record = recover(
                &client_state,
                &client_hint,
                &answer_cipher,
                &query_cipher,
                plain_mod,
            );
            assert_eq!(record, database.get(index).unwrap())
        }
    }

    #[test]
    fn test1() {
        let mod_power = 3;
        let database = Database::new_random_seed(40, mod_power, 42);
        let compressed = database.compress().unwrap();
        let plain_mod = 2_u64.pow(u32::from(mod_power));
        assert_round_trips(&database, Some(&compressed), 2048, plain_mod, 0..1);
    }

    #[test]
    fn test2() {
        let mod_power = 17;
        let database = Database::new_random_seed(1000, mod_power, 42);
        let compressed = database.compress().unwrap();
        let plain_mod = 2_u64.pow(u32::from(mod_power));
        assert_round_trips(&database, Some(&compressed), 2048, plain_mod, 0..100);
    }

    #[test]
    fn test3() {
        let mod_power = 3;
        let database = Database::new_random_seed(1000, mod_power, 42);
        println!("Database Modulus: {}", database.modulus);
        let plain_mod = 2_u64.pow(u32::from(mod_power));
        assert_round_trips(&database, None, 10, plain_mod, 0..100);
    }

    #[test]
    fn test4() {
        let mod_power = 17;
        let database = Database::new_random_seed(38, mod_power, 42);
        let compressed = database.compress().unwrap();
        let plain_mod = 2_u64.pow(u32::from(mod_power));
        assert_round_trips(&database, Some(&compressed), 1000, plain_mod, 0..1444);
    }
}