use super::tests_unsigned::{
nb_tests_for_params, nb_tests_smaller_for_params, overflowing_add_under_modulus,
overflowing_mul_under_modulus, overflowing_sub_under_modulus, random_non_zero_value,
MAX_NB_CTXT,
};
use crate::integer::block_decomposition::BlockDecomposer;
use crate::integer::ciphertext::boolean_value::BooleanBlock;
use crate::integer::keycache::KEY_CACHE;
use crate::integer::{
IntegerKeyKind, IntegerRadixCiphertext, RadixCiphertext, RadixClientKey, ServerKey,
};
use crate::shortint::parameters::*;
use rand::Rng;
use std::sync::Arc;
/// Number of radix blocks used by the tests in this file that build their
/// `RadixClientKey` with `NB_CTXT`.
///
/// A smaller value is selected when compiled with `cfg(tarpaulin)`
/// (NOTE(review): presumably to shorten coverage runs — confirm).
#[cfg(not(tarpaulin))]
pub(crate) const NB_CTXT: usize = 4;
/// Reduced block count used when compiled with `cfg(tarpaulin)`.
#[cfg(tarpaulin)]
pub(crate) const NB_CTXT: usize = 2;
/// Abstraction over the operation under test, so one generic test body can be
/// driven by different executor implementations.
///
/// The test functions in this file call `setup` once with the keys, then call
/// `execute` repeatedly with the test inputs.
pub(crate) trait FunctionExecutor<TestInput, TestOutput> {
    /// Receives the client key and a shared handle to the server key before any
    /// call to `execute`.
    fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>);
    /// Runs the operation under test on `input` and returns its output.
    fn execute(&mut self, input: TestInput) -> TestOutput;
}
pub(crate) use crate::integer::server_key::radix_parallel::tests_unsigned::test_add::unchecked_add_test;
#[cfg(feature = "gpu")]
pub(crate) use crate::integer::server_key::radix_parallel::tests_unsigned::test_add::{
default_add_test, unchecked_add_assign_test,
};
#[cfg(feature = "gpu")]
pub(crate) use crate::integer::server_key::radix_parallel::tests_unsigned::test_aes::{
aes_dynamic_parallelism_many_inputs_test, aes_fixed_parallelism_1_input_test,
aes_fixed_parallelism_2_inputs_test,
};
#[cfg(feature = "gpu")]
pub(crate) use crate::integer::server_key::radix_parallel::tests_unsigned::test_aes256::{
aes_256_dynamic_parallelism_many_inputs_test, aes_256_fixed_parallelism_1_input_test,
aes_256_fixed_parallelism_2_inputs_test,
};
#[cfg(feature = "gpu")]
pub(crate) use crate::integer::server_key::radix_parallel::tests_unsigned::test_neg::default_neg_test;
pub(crate) use crate::integer::server_key::radix_parallel::tests_unsigned::test_neg::unchecked_neg_test;
#[cfg(feature = "gpu")]
pub(crate) use crate::integer::server_key::radix_parallel::tests_unsigned::test_sub::default_sub_test;
pub(crate) use crate::integer::server_key::radix_parallel::tests_unsigned::test_sub::unchecked_sub_test;
#[cfg(feature = "gpu")]
pub(crate) use crate::integer::server_key::radix_parallel::tests_unsigned::test_sum::default_sum_ciphertexts_vec_test;
#[cfg(feature = "gpu")]
pub(crate) use crate::integer::server_key::radix_parallel::tests_unsigned::test_vector_find::{
default_contains_clear_test_case, default_contains_test_case,
default_first_index_in_clears_test_case, default_first_index_of_clear_test_case,
default_first_index_of_test_case, default_index_in_clears_test_case,
default_index_of_clear_test_case, default_index_of_test_case, default_is_in_clears_test_case,
default_match_value_or_test_case, default_match_value_test_case,
unchecked_contains_clear_test_case, unchecked_contains_test_case,
unchecked_first_index_in_clears_test_case, unchecked_first_index_of_clear_test_case,
unchecked_first_index_of_test_case, unchecked_index_in_clears_test_case,
unchecked_index_of_clear_test_case, unchecked_index_of_test_case,
unchecked_is_in_clears_test_case, unchecked_match_value_or_test_case,
unchecked_match_value_test_case,
};
use crate::shortint::server_key::CiphertextNoiseDegree;
/// Multiplies random pairs of encrypted `NB_CTXT`-block values through the executor.
///
/// The decrypted product must equal the clear wrapping product reduced modulo
/// `message_modulus^NB_CTXT`.
pub(crate) fn unchecked_mul_test<P, T>(param: P, mut executor: T)
where
    P: Into<TestParameters>,
    T: for<'a> FunctionExecutor<(&'a RadixCiphertext, &'a RadixCiphertext), RadixCiphertext>,
{
    let param = param.into();
    let iterations = nb_tests_for_params(param);
    let (client_key, server_key) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
    let server_key = Arc::new(server_key);
    let cks = RadixClientKey::from((client_key, NB_CTXT));
    let mut rng = rand::thread_rng();
    let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32);

    executor.setup(&cks, server_key);

    for _ in 0..iterations {
        let (lhs, rhs) = (rng.gen::<u64>() % modulus, rng.gen::<u64>() % modulus);
        let lhs_ct = cks.encrypt(lhs);
        let rhs_ct = cks.encrypt(rhs);

        let product_ct = executor.execute((&lhs_ct, &rhs_ct));
        let decrypted: u64 = cks.decrypt(&product_ct);

        assert_eq!(decrypted, lhs.wrapping_mul(rhs) % modulus);
    }
}
/// Multiplies an encrypted `NB_CTXT`-block value by a single encrypted block used at a
/// random block position `index`.
///
/// The expected value is `radix * block * message_modulus^index`, computed with
/// wrapping multiplications and reduced modulo `message_modulus^NB_CTXT`.
pub(crate) fn unchecked_block_mul_test<P, T>(param: P, mut executor: T)
where
    P: Into<TestParameters>,
    T: for<'a> FunctionExecutor<
        (&'a RadixCiphertext, &'a crate::shortint::Ciphertext, usize),
        RadixCiphertext,
    >,
{
    let param = param.into();
    let iterations = nb_tests_for_params(param);
    let (client_key, server_key) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
    let server_key = Arc::new(server_key);
    let cks = RadixClientKey::from((client_key, NB_CTXT));
    let mut rng = rand::thread_rng();
    let block_modulus = cks.parameters().message_modulus().0;
    let modulus = block_modulus.pow(NB_CTXT as u32);

    executor.setup(&cks, server_key);

    for _ in 0..iterations {
        let radix_clear = rng.gen::<u64>() % modulus;
        let block_clear = rng.gen::<u64>() % block_modulus;
        let block_index = rng.gen_range(0..=(NB_CTXT - 1) as u32);
        // The clear reference scales the product by message_modulus^block_index.
        let position_factor = block_modulus.pow(block_index);

        let radix_ct = cks.encrypt(radix_clear);
        let block_ct = cks.encrypt_one_block(block_clear);
        let result_ct = executor.execute((&radix_ct, &block_ct, block_index as usize));
        let decrypted: u64 = cks.decrypt(&result_ct);

        let expected = radix_clear
            .wrapping_mul(block_clear)
            .wrapping_mul(position_factor)
            % modulus;
        assert_eq!(expected, decrypted);
    }
}
/// Runs multiplication on radix ciphertexts wide enough for 128 bits of message.
///
/// Two fixed corner cases are checked:
/// - a known large `u128` multiplied by an encrypted `u64`;
/// - `u128::MAX * u128::MAX`.
///
/// Both are compared against `u128` wrapping multiplication.
pub(crate) fn unchecked_mul_corner_cases_test<P, T>(param: P, mut executor: T)
where
    P: Into<TestParameters>,
    T: for<'a> FunctionExecutor<(&'a RadixCiphertext, &'a RadixCiphertext), RadixCiphertext>,
{
    let (client_key, server_key) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
    let server_key = Arc::new(server_key);
    // Number of blocks needed to hold 128 message bits.
    let bits_per_block = (client_key.parameters().message_modulus().0 as f64)
        .log2()
        .ceil();
    let num_blocks = (128f64 / bits_per_block).ceil() as usize;
    let cks = RadixClientKey::from((client_key, num_blocks));

    executor.setup(&cks, server_key);

    // Large u128 times a u64 operand.
    {
        let lhs = 307096569525960547621731375222677666984u128;
        let rhs = 5207034748027904122u64;
        let lhs_ct = cks.encrypt(lhs);
        let rhs_ct = cks.encrypt(rhs);
        let product: u128 = cks.decrypt(&executor.execute((&lhs_ct, &rhs_ct)));
        assert_eq!(lhs.wrapping_mul(u128::from(rhs)), product);
    }

    // u128::MAX squared, using two independent encryptions of the same value.
    {
        let operand = u128::MAX;
        let lhs_ct = cks.encrypt(operand);
        let rhs_ct = cks.encrypt(operand);
        let product: u128 = cks.decrypt(&executor.execute((&lhs_ct, &rhs_ct)));
        assert_eq!(operand.wrapping_mul(operand), product);
    }
}
/// Adds a random clear scalar to a random encrypted `NB_CTXT`-block value.
///
/// The decrypted sum must equal the wrapping sum reduced modulo `message_modulus^NB_CTXT`.
pub(crate) fn unchecked_scalar_add_test<P, T>(param: P, mut executor: T)
where
    P: Into<TestParameters>,
    T: for<'a> FunctionExecutor<(&'a RadixCiphertext, u64), RadixCiphertext>,
{
    let param = param.into();
    let iterations = nb_tests_for_params(param);
    let (client_key, server_key) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
    let cks = RadixClientKey::from((client_key, NB_CTXT));
    let mut rng = rand::thread_rng();
    let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32);

    executor.setup(&cks, Arc::new(server_key));

    for _ in 0..iterations {
        let clear_0 = rng.gen::<u64>() % modulus;
        let clear_1 = rng.gen::<u64>() % modulus;

        let encrypted_result = executor.execute((&cks.encrypt(clear_0), clear_1));
        let decrypted_result: u64 = cks.decrypt(&encrypted_result);
        let expected_result = clear_0.wrapping_add(clear_1) % modulus;

        assert_eq!(
            decrypted_result, expected_result,
            "Invalid add result, expected {clear_0} + {clear_1} \
            to be {expected_result}, but got {decrypted_result}."
        );
    }
}
/// Subtracts a random clear scalar from a random encrypted `NB_CTXT`-block value.
///
/// The decrypted difference must equal the wrapping difference reduced modulo
/// `message_modulus^NB_CTXT`.
pub(crate) fn unchecked_scalar_sub_test<P, T>(param: P, mut executor: T)
where
    P: Into<TestParameters>,
    T: for<'a> FunctionExecutor<(&'a RadixCiphertext, u64), RadixCiphertext>,
{
    let param = param.into();
    let iterations = nb_tests_for_params(param);
    let (client_key, server_key) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
    let cks = RadixClientKey::from((client_key, NB_CTXT));
    let mut rng = rand::thread_rng();
    let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32);

    executor.setup(&cks, Arc::new(server_key));

    for _ in 0..iterations {
        let clear_0 = rng.gen::<u64>() % modulus;
        let clear_1 = rng.gen::<u64>() % modulus;

        let encrypted_result = executor.execute((&cks.encrypt(clear_0), clear_1));
        let decrypted_result: u64 = cks.decrypt(&encrypted_result);
        let expected_result = clear_0.wrapping_sub(clear_1) % modulus;

        assert_eq!(
            decrypted_result, expected_result,
            "Invalid sub result, expected {clear_0} - {clear_1} \
            to be {expected_result}, but got {decrypted_result}."
        );
    }
}
/// Fixed corner cases for multiplication of a radix ciphertext by a clear `u64` scalar.
///
/// - On a 128-bit-wide radix: a known large value times a known scalar, and
///   `u128::MAX * u64::MAX`, both checked against `u128` wrapping multiplication.
/// - On an 8-bit-wide radix: `123` times a scalar that is a multiple of 256, checked
///   modulo 256. The executor is set up again with the smaller key before this case.
pub(crate) fn unchecked_scalar_mul_corner_cases_test<P, T>(param: P, mut executor: T)
where
    P: Into<TestParameters>,
    T: for<'a> FunctionExecutor<(&'a RadixCiphertext, u64), RadixCiphertext>,
{
    let (client_key, server_key) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
    let server_key = Arc::new(server_key);
    let bits_per_block = (client_key.parameters().message_modulus().0 as f64)
        .log2()
        .ceil();
    let cks = RadixClientKey::from((client_key, (128f64 / bits_per_block).ceil() as usize));

    executor.setup(&cks, Arc::clone(&server_key));

    // 128-bit operands multiplied by 64-bit scalars.
    for (clear, scalar) in [
        (307096569525960547621731375222677666984u128, 5207034748027904122u64),
        (u128::MAX, u64::MAX),
    ] {
        let ct = cks.encrypt(clear);
        let ct_res = executor.execute((&ct, scalar));
        let dec_res: u128 = cks.decrypt(&ct_res);
        assert_eq!(clear.wrapping_mul(u128::from(scalar)), dec_res);
    }

    // 8-bit radix: the scalar's low 8 bits are all zero.
    let client_key: crate::integer::ClientKey = cks.into();
    let bits_per_block = (client_key.parameters().message_modulus().0 as f64)
        .log2()
        .ceil();
    let cks = RadixClientKey::from((client_key, (8f64 / bits_per_block).ceil() as usize));
    executor.setup(&cks, server_key);

    let clear = 123u64;
    let scalar = 17823812983255694336u64;
    assert_eq!(scalar % 256, 0);
    let ct = cks.encrypt(clear);
    let dec_res: u64 = cks.decrypt(&executor.execute((&ct, scalar)));
    assert_eq!(clear.wrapping_mul(scalar) % 256, dec_res);
}
/// Checks left shifts of an encrypted `NB_CTXT`-block value by a clear amount.
///
/// For each random value, two shift amounts are tested:
/// - an in-range amount (`scalar % nb_bits`);
/// - an out-of-range amount (`scalar + nb_bits`, saturating). Its expected result uses
///   the amount taken modulo `nb_bits`.
///
/// A final pass tests every shift amount smaller than the bit width of one block.
pub(crate) fn unchecked_scalar_left_shift_test<P, T>(param: P, mut executor: T)
where
    P: Into<TestParameters>,
    T: for<'a> FunctionExecutor<(&'a RadixCiphertext, u64), RadixCiphertext>,
{
    let param = param.into();
    let iterations = nb_tests_for_params(param);
    let (client_key, server_key) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
    let server_key = Arc::new(server_key);
    let cks = RadixClientKey::from((client_key, NB_CTXT));
    let mut rng = rand::thread_rng();
    let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32);
    let nb_bits = modulus.ilog2();

    executor.setup(&cks, server_key);

    // Shifts `ct` (an encryption of `clear`) by `scalar` through the executor.
    // The result is compared with the clear value shifted by `shift`.
    let mut check_shift = |clear: u64, ct: &RadixCiphertext, scalar: u32, shift: u32| {
        let encrypted_result = executor.execute((ct, u64::from(scalar)));
        let decrypted_result: u64 = cks.decrypt(&encrypted_result);
        let expected = (clear << u64::from(shift)) % modulus;
        assert_eq!(
            expected, decrypted_result,
            "Invalid left shift result for {clear} << {scalar}: \
            expected {expected}, got {decrypted_result}"
        );
    };

    for _ in 0..iterations {
        let clear = rng.gen::<u64>() % modulus;
        let raw_scalar = rng.gen::<u32>();
        let ct = cks.encrypt(clear);

        let in_range = raw_scalar % nb_bits;
        check_shift(clear, &ct, in_range, in_range);

        let out_of_range = raw_scalar.saturating_add(nb_bits);
        check_shift(clear, &ct, out_of_range, out_of_range % nb_bits);
    }

    let clear = rng.gen::<u64>() % modulus;
    let ct = cks.encrypt(clear);
    let nb_bits_in_block = cks.parameters().message_modulus().0.ilog2();
    for scalar in 0..nb_bits_in_block {
        check_shift(clear, &ct, scalar, scalar);
    }
}
/// Checks right shifts of an encrypted `NB_CTXT`-block value by a clear amount.
///
/// Every result must have empty block carries. For each random value, two shift
/// amounts are tested:
/// - an in-range amount (`scalar % nb_bits`);
/// - an out-of-range amount (`scalar + nb_bits`, saturating). Its expected result uses
///   the amount taken modulo `nb_bits`.
///
/// A final pass tests every shift amount smaller than the bit width of one block.
pub(crate) fn unchecked_scalar_right_shift_test<P, T>(param: P, mut executor: T)
where
    P: Into<TestParameters>,
    T: for<'a> FunctionExecutor<(&'a RadixCiphertext, u64), RadixCiphertext>,
{
    let param = param.into();
    let iterations = nb_tests_for_params(param);
    let (client_key, server_key) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
    let server_key = Arc::new(server_key);
    let cks = RadixClientKey::from((client_key, NB_CTXT));
    let mut rng = rand::thread_rng();
    let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32);
    let nb_bits = modulus.ilog2();

    executor.setup(&cks, server_key);

    // Shifts `ct` (an encryption of `clear`) by `scalar` through the executor.
    // The result is compared with the clear value shifted by `shift`.
    let mut check_shift = |clear: u64, ct: &RadixCiphertext, scalar: u32, shift: u32| {
        let encrypted_result = executor.execute((ct, u64::from(scalar)));
        assert!(encrypted_result.block_carries_are_empty());
        let decrypted_result: u64 = cks.decrypt(&encrypted_result);
        let expected = clear >> u64::from(shift);
        assert_eq!(
            expected, decrypted_result,
            "Invalid right shift result for {clear} >> {scalar}: \
            expected {expected}, got {decrypted_result}"
        );
    };

    for _ in 0..iterations {
        let clear = rng.gen::<u64>() % modulus;
        let raw_scalar = rng.gen::<u32>();
        let ct = cks.encrypt(clear);

        let in_range = raw_scalar % nb_bits;
        check_shift(clear, &ct, in_range, in_range);

        let out_of_range = raw_scalar.saturating_add(nb_bits);
        check_shift(clear, &ct, out_of_range, out_of_range % nb_bits);
    }

    let clear = rng.gen::<u64>() % modulus;
    let ct = cks.encrypt(clear);
    let nb_bits_in_block = cks.parameters().message_modulus().0.ilog2();
    for scalar in 0..nb_bits_in_block {
        check_shift(clear, &ct, scalar, scalar);
    }
}
/// Chains smart multiplications by the same right-hand side.
///
/// Each random pair is multiplied `nb_tests_smaller + 1` times, feeding each product
/// back in as the next left-hand side. The final decryption must equal
/// `clear1 * clear2^(n+1)`, with reduction modulo `message_modulus^NB_CTXT` after
/// every step.
pub(crate) fn smart_mul_test<P, T>(param: P, mut executor: T)
where
    P: Into<TestParameters>,
    T: for<'a> FunctionExecutor<
        (&'a mut RadixCiphertext, &'a mut RadixCiphertext),
        RadixCiphertext,
    >,
{
    let param = param.into();
    let nb_tests_smaller = nb_tests_smaller_for_params(param);
    let (client_key, server_key) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
    let server_key = Arc::new(server_key);
    let cks = RadixClientKey::from((client_key, NB_CTXT));
    let mut rng = rand::thread_rng();
    let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32);

    executor.setup(&cks, server_key);

    for _ in 0..nb_tests_smaller {
        let clear1 = rng.gen::<u64>() % modulus;
        let clear2 = rng.gen::<u64>() % modulus;
        let mut acc_ct = cks.encrypt(clear1);
        let mut rhs_ct = cks.encrypt(clear2);

        // One initial product plus `nb_tests_smaller` chained ones.
        let mut expected = clear1;
        for _ in 0..=nb_tests_smaller {
            acc_ct = executor.execute((&mut acc_ct, &mut rhs_ct));
            expected = (expected * clear2) % modulus;
        }

        let decrypted: u64 = cks.decrypt(&acc_ct);
        assert_eq!(expected, decrypted);
    }
}
/// Chains smart block multiplications.
///
/// A radix ciphertext is repeatedly multiplied by the same encrypted block at a fixed
/// random position `index`: one initial call plus `nb_tests_smaller` chained calls.
/// Each step multiplies the clear reference by `block * message_modulus^index`,
/// using wrapping arithmetic modulo `message_modulus^NB_CTXT`.
pub(crate) fn smart_block_mul_test<P, T>(param: P, mut executor: T)
where
    P: Into<TestParameters>,
    T: for<'a> FunctionExecutor<
        (
            &'a mut RadixCiphertext,
            &'a mut crate::shortint::Ciphertext,
            usize,
        ),
        RadixCiphertext,
    >,
{
    let param = param.into();
    let nb_tests_smaller = nb_tests_smaller_for_params(param);
    let (client_key, server_key) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
    let server_key = Arc::new(server_key);
    let cks = RadixClientKey::from((client_key, NB_CTXT));
    let mut rng = rand::thread_rng();
    let block_modulus = cks.parameters().message_modulus().0;
    let modulus = block_modulus.pow(NB_CTXT as u32);

    executor.setup(&cks, server_key);

    for _ in 0..nb_tests_smaller {
        let clear1 = rng.gen::<u64>() % modulus;
        let clear2 = rng.gen::<u64>() % block_modulus;
        let mut acc_ct = cks.encrypt(clear1);
        let mut block_ct = cks.encrypt_one_block(clear2);
        let position = rng.gen_range(0..=(NB_CTXT - 1) as u32);
        let step = clear2.wrapping_mul(block_modulus.pow(position));
        let position = position as usize;

        let mut expected = clear1;
        for _ in 0..=nb_tests_smaller {
            acc_ct = executor.execute((&mut acc_ct, &mut block_ct, position));
            expected = expected.wrapping_mul(step) % modulus;
        }

        let decrypted: u64 = cks.decrypt(&acc_ct);
        assert_eq!(expected, decrypted);
    }
}
/// Chains smart bitwise ANDs.
///
/// A random pair is ANDed, then the result is repeatedly ANDed with fresh random
/// encryptions. The decrypted accumulator is checked after every chained step.
pub(crate) fn smart_bitand_test<P, T>(param: P, mut executor: T)
where
    P: Into<TestParameters>,
    T: for<'a> FunctionExecutor<
        (&'a mut RadixCiphertext, &'a mut RadixCiphertext),
        RadixCiphertext,
    >,
{
    let param = param.into();
    let rounds = nb_tests_smaller_for_params(param);
    let (client_key, server_key) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
    let server_key = Arc::new(server_key);
    let cks = RadixClientKey::from((client_key, NB_CTXT));
    let mut rng = rand::thread_rng();
    let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32);

    executor.setup(&cks, server_key);

    for _ in 0..rounds {
        let first = rng.gen::<u64>() % modulus;
        let second = rng.gen::<u64>() % modulus;
        let mut first_ct = cks.encrypt(first);
        let mut second_ct = cks.encrypt(second);

        let mut acc_ct = executor.execute((&mut first_ct, &mut second_ct));
        let mut expected = first & second;

        for _ in 0..rounds {
            let next = rng.gen::<u64>() % modulus;
            let mut next_ct = cks.encrypt(next);
            acc_ct = executor.execute((&mut acc_ct, &mut next_ct));
            expected &= next;

            let decrypted: u64 = cks.decrypt(&acc_ct);
            assert_eq!(expected, decrypted);
        }
    }
}
/// Chains smart bitwise ORs.
///
/// A random pair is ORed, then the result is repeatedly ORed with fresh random
/// encryptions. The decrypted accumulator is checked after every chained step.
pub(crate) fn smart_bitor_test<P, T>(param: P, mut executor: T)
where
    P: Into<TestParameters>,
    T: for<'a> FunctionExecutor<
        (&'a mut RadixCiphertext, &'a mut RadixCiphertext),
        RadixCiphertext,
    >,
{
    let param = param.into();
    let rounds = nb_tests_smaller_for_params(param);
    let (client_key, server_key) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
    let server_key = Arc::new(server_key);
    let cks = RadixClientKey::from((client_key, NB_CTXT));
    let mut rng = rand::thread_rng();
    let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32);

    executor.setup(&cks, server_key);

    for _ in 0..rounds {
        let first = rng.gen::<u64>() % modulus;
        let second = rng.gen::<u64>() % modulus;
        let mut first_ct = cks.encrypt(first);
        let mut second_ct = cks.encrypt(second);

        let mut acc_ct = executor.execute((&mut first_ct, &mut second_ct));
        let mut expected = (first | second) % modulus;

        for _ in 0..rounds {
            let next = rng.gen::<u64>() % modulus;
            let mut next_ct = cks.encrypt(next);
            acc_ct = executor.execute((&mut acc_ct, &mut next_ct));
            expected = (expected | next) % modulus;

            let decrypted: u64 = cks.decrypt(&acc_ct);
            assert_eq!(expected, decrypted);
        }
    }
}
/// Chains smart bitwise XORs.
///
/// A random pair is XORed, then the result is repeatedly XORed with fresh random
/// encryptions. The decrypted accumulator is checked after every chained step.
pub(crate) fn smart_bitxor_test<P, T>(param: P, mut executor: T)
where
    P: Into<TestParameters>,
    T: for<'a> FunctionExecutor<
        (&'a mut RadixCiphertext, &'a mut RadixCiphertext),
        RadixCiphertext,
    >,
{
    let param = param.into();
    let rounds = nb_tests_smaller_for_params(param);
    let (client_key, server_key) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
    let server_key = Arc::new(server_key);
    let cks = RadixClientKey::from((client_key, NB_CTXT));
    let mut rng = rand::thread_rng();
    let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32);

    executor.setup(&cks, server_key);

    for _ in 0..rounds {
        let first = rng.gen::<u64>() % modulus;
        let second = rng.gen::<u64>() % modulus;
        let mut first_ct = cks.encrypt(first);
        let mut second_ct = cks.encrypt(second);

        let mut acc_ct = executor.execute((&mut first_ct, &mut second_ct));
        let mut expected = (first ^ second) % modulus;

        for _ in 0..rounds {
            let next = rng.gen::<u64>() % modulus;
            let mut next_ct = cks.encrypt(next);
            acc_ct = executor.execute((&mut acc_ct, &mut next_ct));
            expected = (expected ^ next) % modulus;

            let decrypted: u64 = cks.decrypt(&acc_ct);
            assert_eq!(expected, decrypted);
        }
    }
}
/// Chains smart scalar additions.
///
/// A random encrypted value is repeatedly incremented by the same random clear scalar.
/// The decrypted accumulator is checked after every chained step, modulo
/// `message_modulus^NB_CTXT`.
pub(crate) fn smart_scalar_add_test<P, T>(param: P, mut executor: T)
where
    P: Into<TestParameters>,
    T: for<'a> FunctionExecutor<(&'a mut RadixCiphertext, u64), RadixCiphertext>,
{
    let param = param.into();
    let rounds = nb_tests_smaller_for_params(param);
    let (client_key, server_key) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
    let server_key = Arc::new(server_key);
    let cks = RadixClientKey::from((client_key, NB_CTXT));
    let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32);
    let mut rng = rand::thread_rng();

    executor.setup(&cks, server_key);

    for _ in 0..rounds {
        let start = rng.gen::<u64>() % modulus;
        let scalar = rng.gen::<u64>() % modulus;
        let mut start_ct = cks.encrypt(start);

        let mut acc_ct = executor.execute((&mut start_ct, scalar));
        let mut expected = (start + scalar) % modulus;

        for _ in 0..rounds {
            acc_ct = executor.execute((&mut acc_ct, scalar));
            expected = (expected + scalar) % modulus;

            let decrypted: u64 = cks.decrypt(&acc_ct);
            assert_eq!(expected, decrypted);
        }
    }
}
/// Chains smart scalar subtractions.
///
/// A random encrypted value is repeatedly decremented by the same random clear scalar.
/// The decrypted accumulator is checked after every chained step against wrapping
/// subtraction modulo `message_modulus^NB_CTXT`.
pub(crate) fn smart_scalar_sub_test<P, T>(param: P, mut executor: T)
where
    P: Into<TestParameters>,
    T: for<'a> FunctionExecutor<(&'a mut RadixCiphertext, u64), RadixCiphertext>,
{
    let param = param.into();
    let rounds = nb_tests_smaller_for_params(param);
    let (client_key, server_key) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
    let server_key = Arc::new(server_key);
    let cks = RadixClientKey::from((client_key, NB_CTXT));
    let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32);
    let mut rng = rand::thread_rng();

    executor.setup(&cks, server_key);

    for _ in 0..rounds {
        let start = rng.gen::<u64>() % modulus;
        let scalar = rng.gen::<u64>() % modulus;
        let mut start_ct = cks.encrypt(start);

        let mut acc_ct = executor.execute((&mut start_ct, scalar));
        let mut expected = start.wrapping_sub(scalar) % modulus;

        for _ in 0..rounds {
            acc_ct = executor.execute((&mut acc_ct, scalar));
            expected = expected.wrapping_sub(scalar) % modulus;

            let decrypted: u64 = cks.decrypt(&acc_ct);
            assert_eq!(expected, decrypted);
        }
    }
}
/// Multiplies random encrypted `NB_CTXT`-block values by random clear scalars with the
/// smart operation.
///
/// Each product is checked modulo `message_modulus^NB_CTXT`.
pub(crate) fn smart_scalar_mul_test<P, T>(param: P, mut executor: T)
where
    P: Into<TestParameters>,
    T: for<'a> FunctionExecutor<(&'a mut RadixCiphertext, u64), RadixCiphertext>,
{
    let param = param.into();
    let iterations = nb_tests_for_params(param);
    let (client_key, server_key) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
    let server_key = Arc::new(server_key);
    let cks = RadixClientKey::from((client_key, NB_CTXT));
    let mut rng = rand::thread_rng();
    let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32);

    executor.setup(&cks, server_key);

    for _ in 0..iterations {
        let value = rng.gen::<u64>() % modulus;
        let scalar = rng.gen::<u64>() % modulus;
        let mut value_ct = cks.encrypt(value);

        let decrypted: u64 = cks.decrypt(&executor.execute((&mut value_ct, scalar)));
        assert_eq!((value * scalar) % modulus, decrypted);
    }
}
/// Non-regression test for smart scalar multiplication on a 128-bit-wide radix.
///
/// A random `u128` is multiplied by a random `u64` scalar and compared with `u128`
/// wrapping multiplication.
pub(crate) fn smart_scalar_mul_u128_fix_non_reg_test<P, T>(param: P, mut executor: T)
where
    P: Into<TestParameters>,
    T: for<'a> FunctionExecutor<(&'a mut RadixCiphertext, u64), RadixCiphertext>,
{
    let (client_key, server_key) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
    let server_key = Arc::new(server_key);
    // Number of blocks needed to hold 128 message bits.
    let bits_per_block = (client_key.parameters().message_modulus().0 as f64)
        .log2()
        .ceil();
    let cks = RadixClientKey::from((client_key, (128f64 / bits_per_block).ceil() as usize));
    let mut rng = rand::thread_rng();

    executor.setup(&cks, server_key);

    let (value, scalar) = (rng.gen::<u128>(), rng.gen::<u64>());
    let mut value_ct = cks.encrypt(value);
    let decrypted: u128 = cks.decrypt(&executor.execute((&mut value_ct, scalar)));
    assert_eq!(value.wrapping_mul(u128::from(scalar)), decrypted);
}
/// Checks the default multiplication with deterministic PBS execution enabled.
///
/// For each random pair, `nb_tests_smaller + 1` multiplications by the same right-hand
/// side are chained. After every chained step the test checks two things:
/// - the block carries are empty;
/// - running the same multiplication twice gives equal ciphertexts.
///
/// It then multiplies by a trivially-encrypted 0 or 1, in both operand orders.
pub(crate) fn default_mul_test<P, T>(param: P, mut executor: T)
where
    P: Into<TestParameters>,
    T: for<'a> FunctionExecutor<(&'a RadixCiphertext, &'a RadixCiphertext), RadixCiphertext>,
{
    let param = param.into();
    let nb_tests_smaller = nb_tests_smaller_for_params(param);
    let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
    let cks = RadixClientKey::from((cks, NB_CTXT));
    // NOTE(review): presumably required so that the repeated-execution equality
    // check below can hold; confirm.
    sks.set_deterministic_pbs_execution(true);
    let sks = Arc::new(sks);
    let mut rng = rand::thread_rng();
    let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32);
    executor.setup(&cks, sks.clone());
    for _ in 0..nb_tests_smaller {
        let clear1 = rng.gen::<u64>() % modulus;
        let clear2 = rng.gen::<u64>() % modulus;
        let ctxt_1 = cks.encrypt(clear1);
        let ctxt_2 = cks.encrypt(clear2);
        let mut clear = clear1;
        // Multiply directly from `ctxt_1`; the previous clone into `res` was immediately
        // overwritten.
        let mut res = executor.execute((&ctxt_1, &ctxt_2));
        assert!(res.block_carries_are_empty());
        for _ in 0..nb_tests_smaller {
            // Same multiplication executed twice: both outputs must be equal.
            let tmp = executor.execute((&res, &ctxt_2));
            res = executor.execute((&res, &ctxt_2));
            assert!(res.block_carries_are_empty());
            // A single labeled assertion. The former unlabeled duplicate fired first
            // and hid this diagnostic message.
            assert_eq!(res, tmp, "Failed determinism check, \n\n\n msg0: {clear1}, msg1: {clear2}, \n\n\nctxt0: {ctxt_1:?}, \n\n\nctxt1: {ctxt_2:?}\n\n\n");
            clear = (clear * clear2) % modulus;
        }
        let dec: u64 = cks.decrypt(&res);
        // Accounts for the initial multiplication performed before the loop.
        clear = (clear * clear2) % modulus;
        assert_eq!(clear, dec);
    }
    // Multiplication by a trivial boolean (0 or 1), with the boolean on either side.
    {
        let clear1 = rng.gen::<u64>() % modulus;
        let clear2 = rng.gen_range(0u64..=1);
        let ctxt_1 = cks.encrypt(clear1);
        let ctxt_2: RadixCiphertext = sks.create_trivial_radix(clear2, ctxt_1.blocks.len());
        assert!(ctxt_2.holds_boolean_value());
        let res = executor.execute((&ctxt_1, &ctxt_2));
        let dec: u64 = cks.decrypt(&res);
        assert_eq!(dec, clear1 * clear2);
        let res = executor.execute((&ctxt_2, &ctxt_1));
        let dec: u64 = cks.decrypt(&res);
        assert_eq!(dec, clear1 * clear2);
    }
}
/// Checks the default overflowing multiplication with deterministic PBS execution enabled.
///
/// Three input families are covered:
/// 1. Fresh encryptions. The test also checks that two runs on the same inputs give
///    equal outputs.
/// 2. Inputs produced by `unchecked_scalar_add` of random non-zero values.
///    NOTE(review): presumably these have non-empty carries; confirm.
/// 3. Trivial encryptions, including a zero operand on either side.
///
/// Results are compared with `overflowing_mul_under_modulus`. The overflow flag must
/// have degree 1. Its noise level must be `NOMINAL` for encrypted inputs and `ZERO`
/// for trivial ones.
pub(crate) fn default_overflowing_mul_test<P, T>(param: P, mut executor: T)
where
    P: Into<TestParameters>,
    T: for<'a> FunctionExecutor<
        (&'a RadixCiphertext, &'a RadixCiphertext),
        (RadixCiphertext, BooleanBlock),
    >,
{
    let param = param.into();
    let nb_tests_smaller = nb_tests_smaller_for_params(param);
    let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
    let cks = RadixClientKey::from((cks, NB_CTXT));
    sks.set_deterministic_pbs_execution(true);
    let sks = Arc::new(sks);
    let mut rng = rand::thread_rng();
    let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32);
    executor.setup(&cks, sks.clone());
    for _ in 0..nb_tests_smaller {
        let clear_0 = rng.gen::<u64>() % modulus;
        let clear_1 = rng.gen::<u64>() % modulus;
        let ctxt_0 = cks.encrypt(clear_0);
        let ctxt_1 = cks.encrypt(clear_1);
        let (ct_res, result_overflowed) = executor.execute((&ctxt_0, &ctxt_1));
        // Same inputs executed twice: both outputs must be equal.
        let (tmp_ct, tmp_o) = executor.execute((&ctxt_0, &ctxt_1));
        assert!(ct_res.block_carries_are_empty());
        assert_eq!(ct_res, tmp_ct, "Failed determinism check, \n\n\n msg0: {clear_0}, msg1: {clear_1}, \n\n\nctxt0: {ctxt_0:?}, \n\n\nctxt1: {ctxt_1:?}\n\n\n");
        assert_eq!(tmp_o, result_overflowed, "Failed determinism check, \n\n\n msg0: {clear_0}, msg1: {clear_1}, \n\n\nctxt0: {ctxt_0:?}, \n\n\nctxt1: {ctxt_1:?}\n\n\n");
        let (expected_result, expected_overflowed) =
            overflowing_mul_under_modulus(clear_0, clear_1, modulus);
        let decrypted_result: u64 = cks.decrypt(&ct_res);
        let decrypted_overflowed = cks.decrypt_bool(&result_overflowed);
        assert_eq!(
            decrypted_result, expected_result,
            "Invalid result for mul, for ({clear_0} * {clear_1}) % {modulus} \
             expected {expected_result}, got {decrypted_result}"
        );
        // The line continuation keeps the message on one line. Previously a raw newline
        // and indentation were embedded in it.
        assert_eq!(
            decrypted_overflowed,
            expected_overflowed,
            "Invalid overflow flag result for overflowing_mul for ({clear_0} * {clear_1}) % {modulus} \
             expected overflow flag {expected_overflowed}, got {decrypted_overflowed}"
        );
        assert_eq!(result_overflowed.0.degree.get(), 1);
        assert_eq!(result_overflowed.0.noise_level(), NoiseLevel::NOMINAL);
        for _ in 0..nb_tests_smaller {
            let clear_2 = random_non_zero_value(&mut rng, modulus);
            let clear_3 = random_non_zero_value(&mut rng, modulus);
            let ctxt_0 = sks.unchecked_scalar_add(&ctxt_0, clear_2);
            let ctxt_1 = sks.unchecked_scalar_add(&ctxt_1, clear_3);
            let clear_lhs = clear_0.wrapping_add(clear_2) % modulus;
            let clear_rhs = clear_1.wrapping_add(clear_3) % modulus;
            let d0: u64 = cks.decrypt(&ctxt_0);
            assert_eq!(d0, clear_lhs, "Failed sanity decryption check");
            let d1: u64 = cks.decrypt(&ctxt_1);
            assert_eq!(d1, clear_rhs, "Failed sanity decryption check");
            let (ct_res, result_overflowed) = executor.execute((&ctxt_0, &ctxt_1));
            assert!(ct_res.block_carries_are_empty());
            let (expected_result, expected_overflowed) =
                overflowing_mul_under_modulus(clear_lhs, clear_rhs, modulus);
            let decrypted_result: u64 = cks.decrypt(&ct_res);
            let decrypted_overflowed = cks.decrypt_bool(&result_overflowed);
            assert_eq!(
                decrypted_result, expected_result,
                "Invalid result for mul, for ({clear_lhs} * {clear_rhs}) % {modulus} \
                 expected {expected_result}, got {decrypted_result}"
            );
            // The operator shown is `*`; it previously printed `-`.
            assert_eq!(
                decrypted_overflowed,
                expected_overflowed,
                "Invalid overflow flag result for overflowing_mul, for ({clear_lhs} * {clear_rhs}) % {modulus} \
                 expected overflow flag {expected_overflowed}, got {decrypted_overflowed}"
            );
            assert_eq!(result_overflowed.0.degree.get(), 1);
            assert_eq!(result_overflowed.0.noise_level(), NoiseLevel::NOMINAL);
        }
    }
    // Trivial encryptions, including a zero on either side.
    let values = [
        (rng.gen::<u64>() % modulus, rng.gen::<u64>() % modulus),
        (rng.gen::<u64>() % modulus, rng.gen::<u64>() % modulus),
        (rng.gen::<u64>() % modulus, rng.gen::<u64>() % modulus),
        (rng.gen::<u64>() % modulus, rng.gen::<u64>() % modulus),
        (rng.gen::<u64>() % modulus, 0),
        (0, rng.gen::<u64>() % modulus),
    ];
    for (clear_0, clear_1) in values {
        let a: RadixCiphertext = sks.create_trivial_radix(clear_0, NB_CTXT);
        let b: RadixCiphertext = sks.create_trivial_radix(clear_1, NB_CTXT);
        let (encrypted_result, encrypted_overflow) = executor.execute((&a, &b));
        let (expected_result, expected_overflowed) =
            overflowing_mul_under_modulus(clear_0, clear_1, modulus);
        let decrypted_result: u64 = cks.decrypt(&encrypted_result);
        let decrypted_overflowed = cks.decrypt_bool(&encrypted_overflow);
        assert_eq!(
            decrypted_result, expected_result,
            "Invalid result for mul, for ({clear_0} * {clear_1}) % {modulus} \
             expected {expected_result}, got {decrypted_result}"
        );
        // The `*` operator was previously missing from this message.
        assert_eq!(
            decrypted_overflowed,
            expected_overflowed,
            "Invalid overflow flag result for overflowing_mul, for ({clear_0} * {clear_1}) % {modulus} \
             expected overflow flag {expected_overflowed}, got {decrypted_overflowed}"
        );
        assert_eq!(encrypted_overflow.0.degree.get(), 1);
        assert_eq!(encrypted_overflow.0.noise_level(), NoiseLevel::ZERO);
    }
}
/// Applies bitwise NOT to random encrypted `NB_CTXT`-block values.
///
/// The expected value is the 64-bit complement reduced modulo `message_modulus^NB_CTXT`.
pub(crate) fn unchecked_bitnot_test<P, T>(param: P, mut executor: T)
where
    P: Into<TestParameters>,
    T: for<'a> FunctionExecutor<&'a RadixCiphertext, RadixCiphertext>,
{
    let param = param.into();
    let iterations = nb_tests_for_params(param);
    let (client_key, server_key) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
    let cks = RadixClientKey::from((client_key, NB_CTXT));
    let server_key = Arc::new(server_key);
    let mut rng = rand::thread_rng();
    let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32);

    executor.setup(&cks, server_key);

    for _ in 0..iterations {
        let value = rng.gen::<u64>() % modulus;
        let value_ct = cks.encrypt(value);
        let decrypted: u64 = cks.decrypt(&executor.execute(&value_ct));
        // Complement all 64 bits first, then reduce to the radix range.
        assert_eq!((!value) % modulus, decrypted);
    }
}
/// Chains unchecked bitwise ANDs.
///
/// A random pair is ANDed, then the result is repeatedly ANDed with fresh random
/// encryptions. The decrypted accumulator is checked after every chained step.
pub(crate) fn unchecked_bitand_test<P, T>(param: P, mut executor: T)
where
    P: Into<TestParameters>,
    T: for<'a> FunctionExecutor<(&'a RadixCiphertext, &'a RadixCiphertext), RadixCiphertext>,
{
    let param = param.into();
    let rounds = nb_tests_smaller_for_params(param);
    let (client_key, server_key) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
    let cks = RadixClientKey::from((client_key, NB_CTXT));
    let server_key = Arc::new(server_key);
    let mut rng = rand::thread_rng();
    let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32);

    executor.setup(&cks, server_key);

    for _ in 0..rounds {
        let first = rng.gen::<u64>() % modulus;
        let second = rng.gen::<u64>() % modulus;
        let first_ct = cks.encrypt(first);
        let second_ct = cks.encrypt(second);

        let mut acc_ct = executor.execute((&first_ct, &second_ct));
        let mut expected = first & second;

        for _ in 0..rounds {
            let next = rng.gen::<u64>() % modulus;
            let next_ct = cks.encrypt(next);
            acc_ct = executor.execute((&acc_ct, &next_ct));
            expected &= next;

            let decrypted: u64 = cks.decrypt(&acc_ct);
            assert_eq!(expected, decrypted);
        }
    }
}
/// Chains unchecked bitwise ORs.
///
/// A random pair is ORed, then the result is repeatedly ORed with fresh random
/// encryptions. The decrypted accumulator is checked after every chained step.
pub(crate) fn unchecked_bitor_test<P, T>(param: P, mut executor: T)
where
    P: Into<TestParameters>,
    T: for<'a> FunctionExecutor<(&'a RadixCiphertext, &'a RadixCiphertext), RadixCiphertext>,
{
    let param = param.into();
    let rounds = nb_tests_smaller_for_params(param);
    let (client_key, server_key) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
    let cks = RadixClientKey::from((client_key, NB_CTXT));
    let server_key = Arc::new(server_key);
    let mut rng = rand::thread_rng();
    let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32);

    executor.setup(&cks, server_key);

    for _ in 0..rounds {
        let first = rng.gen::<u64>() % modulus;
        let second = rng.gen::<u64>() % modulus;
        let first_ct = cks.encrypt(first);
        let second_ct = cks.encrypt(second);

        let mut acc_ct = executor.execute((&first_ct, &second_ct));
        let mut expected = (first | second) % modulus;

        for _ in 0..rounds {
            let next = rng.gen::<u64>() % modulus;
            let next_ct = cks.encrypt(next);
            acc_ct = executor.execute((&acc_ct, &next_ct));
            expected |= next;

            let decrypted: u64 = cks.decrypt(&acc_ct);
            assert_eq!(expected, decrypted);
        }
    }
}
/// Checks the unchecked bitwise XOR of two `NB_CTXT`-block radix ciphertexts.
///
/// Each outer round XORs two freshly encrypted random values. It then keeps XORing the running
/// ciphertext with new encryptions, and compares every decrypted intermediate result against
/// the clear reference value.
pub(crate) fn unchecked_bitxor_test<P, T>(param: P, mut executor: T)
where
    P: Into<TestParameters>,
    T: for<'a> FunctionExecutor<(&'a RadixCiphertext, &'a RadixCiphertext), RadixCiphertext>,
{
    let param = param.into();
    let rounds = nb_tests_smaller_for_params(param);
    let (integer_cks, integer_sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
    let radix_cks = RadixClientKey::from((integer_cks, NB_CTXT));
    let mut rng = rand::thread_rng();
    let modulus = radix_cks.parameters().message_modulus().0.pow(NB_CTXT as u32);
    executor.setup(&radix_cks, Arc::new(integer_sks));

    for _ in 0..rounds {
        let a = rng.gen::<u64>() % modulus;
        let b = rng.gen::<u64>() % modulus;
        let (ct_a, ct_b) = (radix_cks.encrypt(a), radix_cks.encrypt(b));

        let mut running = executor.execute((&ct_a, &ct_b));
        let mut reference = a ^ b;

        for _ in 0..rounds {
            let next = rng.gen::<u64>() % modulus;
            let ct_next = radix_cks.encrypt(next);
            running = executor.execute((&running, &ct_next));
            reference = (reference ^ next) % modulus;

            let decrypted: u64 = radix_cks.decrypt(&running);
            assert_eq!(reference, decrypted);
        }
    }
}
/// Checks the default bitwise AND on `NB_CTXT`-block radix ciphertexts.
///
/// Deterministic PBS execution is enabled on the server key. Each chained step is therefore
/// computed twice and both ciphertexts must be identical. Every result must also have empty
/// carries and decrypt to the clear AND.
pub(crate) fn default_bitand_test<P, T>(param: P, mut executor: T)
where
    P: Into<TestParameters>,
    T: for<'a> FunctionExecutor<(&'a RadixCiphertext, &'a RadixCiphertext), RadixCiphertext>,
{
    let param = param.into();
    let rounds = nb_tests_smaller_for_params(param);
    let (integer_cks, mut integer_sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
    let radix_cks = RadixClientKey::from((integer_cks, NB_CTXT));
    integer_sks.set_deterministic_pbs_execution(true);
    let mut rng = rand::thread_rng();
    let modulus = radix_cks.parameters().message_modulus().0.pow(NB_CTXT as u32);
    executor.setup(&radix_cks, Arc::new(integer_sks));

    for _ in 0..rounds {
        let a = rng.gen::<u64>() % modulus;
        let b = rng.gen::<u64>() % modulus;
        let (ct_a, ct_b) = (radix_cks.encrypt(a), radix_cks.encrypt(b));

        let mut running = executor.execute((&ct_a, &ct_b));
        assert!(running.block_carries_are_empty());
        let mut reference = a & b;

        for _ in 0..rounds {
            let next = rng.gen::<u64>() % modulus;
            let ct_next = radix_cks.encrypt(next);
            // Same inputs twice: with deterministic PBS the outputs must match exactly.
            let first_run = executor.execute((&running, &ct_next));
            running = executor.execute((&running, &ct_next));
            assert!(running.block_carries_are_empty());
            assert_eq!(running, first_run);
            reference &= next;

            let decrypted: u64 = radix_cks.decrypt(&running);
            assert_eq!(reference, decrypted);
        }
    }
}
/// Checks the default bitwise OR on `NB_CTXT`-block radix ciphertexts.
///
/// Deterministic PBS execution is enabled on the server key. Each chained step is therefore
/// computed twice and both ciphertexts must be identical. Every result must also have empty
/// carries and decrypt to the clear OR.
pub(crate) fn default_bitor_test<P, T>(param: P, mut executor: T)
where
    P: Into<TestParameters>,
    T: for<'a> FunctionExecutor<(&'a RadixCiphertext, &'a RadixCiphertext), RadixCiphertext>,
{
    let param = param.into();
    let rounds = nb_tests_smaller_for_params(param);
    let (integer_cks, mut integer_sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
    let radix_cks = RadixClientKey::from((integer_cks, NB_CTXT));
    integer_sks.set_deterministic_pbs_execution(true);
    let mut rng = rand::thread_rng();
    let modulus = radix_cks.parameters().message_modulus().0.pow(NB_CTXT as u32);
    executor.setup(&radix_cks, Arc::new(integer_sks));

    for _ in 0..rounds {
        let a = rng.gen::<u64>() % modulus;
        let b = rng.gen::<u64>() % modulus;
        let (ct_a, ct_b) = (radix_cks.encrypt(a), radix_cks.encrypt(b));

        let mut running = executor.execute((&ct_a, &ct_b));
        assert!(running.block_carries_are_empty());
        let mut reference = (a | b) % modulus;

        for _ in 0..rounds {
            let next = rng.gen::<u64>() % modulus;
            let ct_next = radix_cks.encrypt(next);
            // Same inputs twice: with deterministic PBS the outputs must match exactly.
            let first_run = executor.execute((&running, &ct_next));
            running = executor.execute((&running, &ct_next));
            assert!(running.block_carries_are_empty());
            assert_eq!(running, first_run);
            reference |= next;

            let decrypted: u64 = radix_cks.decrypt(&running);
            assert_eq!(reference, decrypted);
        }
    }
}
/// Checks the default bitwise XOR on `NB_CTXT`-block radix ciphertexts.
///
/// Deterministic PBS execution is enabled on the server key. Each chained step is therefore
/// computed twice and both ciphertexts must be identical. Every result must also have empty
/// carries and decrypt to the clear XOR.
pub(crate) fn default_bitxor_test<P, T>(param: P, mut executor: T)
where
    P: Into<TestParameters>,
    T: for<'a> FunctionExecutor<(&'a RadixCiphertext, &'a RadixCiphertext), RadixCiphertext>,
{
    let param = param.into();
    let rounds = nb_tests_smaller_for_params(param);
    let (integer_cks, mut integer_sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
    let radix_cks = RadixClientKey::from((integer_cks, NB_CTXT));
    integer_sks.set_deterministic_pbs_execution(true);
    let mut rng = rand::thread_rng();
    let modulus = radix_cks.parameters().message_modulus().0.pow(NB_CTXT as u32);
    executor.setup(&radix_cks, Arc::new(integer_sks));

    for _ in 0..rounds {
        let a = rng.gen::<u64>() % modulus;
        let b = rng.gen::<u64>() % modulus;
        let (ct_a, ct_b) = (radix_cks.encrypt(a), radix_cks.encrypt(b));

        let mut running = executor.execute((&ct_a, &ct_b));
        assert!(running.block_carries_are_empty());
        let mut reference = a ^ b;

        for _ in 0..rounds {
            let next = rng.gen::<u64>() % modulus;
            let ct_next = radix_cks.encrypt(next);
            // Same inputs twice: with deterministic PBS the outputs must match exactly.
            let first_run = executor.execute((&running, &ct_next));
            running = executor.execute((&running, &ct_next));
            assert!(running.block_carries_are_empty());
            assert_eq!(running, first_run);
            reference = (reference ^ next) % modulus;

            let decrypted: u64 = radix_cks.decrypt(&running);
            assert_eq!(reference, decrypted);
        }
    }
}
/// Checks the default bitwise NOT on `NB_CTXT`-block radix ciphertexts.
///
/// With deterministic PBS enabled, each NOT is run twice on the same input and both results
/// must be identical. They must also have empty carries and decrypt to the clear `!value`
/// reduced modulo the radix modulus.
pub(crate) fn default_bitnot_test<P, T>(param: P, mut executor: T)
where
    P: Into<TestParameters>,
    T: for<'a> FunctionExecutor<&'a RadixCiphertext, RadixCiphertext>,
{
    let param = param.into();
    let rounds = nb_tests_for_params(param);
    let (integer_cks, mut integer_sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
    let radix_cks = RadixClientKey::from((integer_cks, NB_CTXT));
    integer_sks.set_deterministic_pbs_execution(true);
    let mut rng = rand::thread_rng();
    executor.setup(&radix_cks, Arc::new(integer_sks));
    let modulus = radix_cks.parameters().message_modulus().0.pow(NB_CTXT as u32);

    for _ in 0..rounds {
        let value = rng.gen::<u64>() % modulus;
        let ct = radix_cks.encrypt(value);

        let first_run = executor.execute(&ct);
        let negated = executor.execute(&ct);
        assert!(negated.block_carries_are_empty());
        assert_eq!(negated, first_run);

        let decrypted: u64 = radix_cks.decrypt(&negated);
        // NOT of the full u64, truncated to the radix modulus.
        assert_eq!(!value % modulus, decrypted);
    }
}
/// Checks the default scalar addition on radix ciphertexts of 1..MAX_NB_CTXT blocks.
///
/// For each block count, a random scalar is added to a fresh encryption. The same scalar is
/// then repeatedly added to the running result. With deterministic PBS enabled, each chained
/// step is computed twice and must match. Every result must have empty carries and decrypt to
/// the clear sum modulo the radix modulus.
pub(crate) fn default_scalar_add_test<P, T>(param: P, mut executor: T)
where
    P: Into<TestParameters>,
    T: for<'a> FunctionExecutor<(&'a RadixCiphertext, u64), RadixCiphertext>,
{
    let param = param.into();
    let rounds = nb_tests_smaller_for_params(param);
    let (integer_cks, mut integer_sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
    let radix_cks = RadixClientKey::from((integer_cks, NB_CTXT));
    integer_sks.set_deterministic_pbs_execution(true);
    let mut rng = rand::thread_rng();
    executor.setup(&radix_cks, Arc::new(integer_sks));
    // The block count varies below, so switch back to the plain integer client key.
    let cks: crate::integer::ClientKey = radix_cks.into();

    for num_blocks in 1..MAX_NB_CTXT {
        let modulus = cks.parameters().message_modulus().0.pow(num_blocks as u32);

        for _ in 0..rounds {
            let clear_0 = rng.gen::<u64>() % modulus;
            let clear_1 = rng.gen::<u64>() % modulus;
            let ctxt_0 = cks.encrypt_radix(clear_0, num_blocks);

            let mut running = executor.execute((&ctxt_0, clear_1));
            assert!(running.block_carries_are_empty());
            let mut expected = (clear_0 + clear_1) % modulus;
            let decrypted: u64 = cks.decrypt_radix(&running);
            assert_eq!(
                expected, decrypted,
                "invalid result for ({clear_0} + {clear_1}) % {modulus} (num_blocks: {num_blocks})"
            );

            // Keep adding the same scalar to the running result.
            for _ in 0..rounds {
                let first_run = executor.execute((&running, clear_1));
                running = executor.execute((&running, clear_1));
                assert!(running.block_carries_are_empty());
                assert_eq!(running, first_run);
                expected = expected.wrapping_add(clear_1) % modulus;

                let decrypted: u64 = cks.decrypt_radix(&running);
                assert_eq!(expected, decrypted);
            }
        }
    }
}
/// Checks the overflowing scalar addition on radix ciphertexts of 1..MAX_NB_CTXT blocks.
///
/// For every block count this covers four scenarios:
/// - fresh encryptions plus a random in-range scalar. Each is computed twice under
///   deterministic PBS, and the result and overflow flag must be identical across runs;
/// - lhs ciphertexts first modified with `unchecked_scalar_add`, so the input comes from a
///   server-side operation rather than straight from encryption;
/// - trivially encrypted lhs values (`create_trivial_radix`);
/// - scalars at or above the radix modulus, which must always report an overflow.
///
/// Results must decrypt to `overflowing_add_under_modulus`. The overflow flag must have
/// degree 1 and the noise level asserted for each scenario.
pub(crate) fn default_overflowing_scalar_add_test<P, T>(param: P, mut executor: T)
where
    P: Into<TestParameters>,
    T: for<'a> FunctionExecutor<(&'a RadixCiphertext, u64), (RadixCiphertext, BooleanBlock)>,
{
    let param = param.into();
    let nb_tests_smaller = nb_tests_smaller_for_params(param);
    let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
    let cks = RadixClientKey::from((cks, NB_CTXT));
    sks.set_deterministic_pbs_execution(true);
    let sks = Arc::new(sks);
    let mut rng = rand::thread_rng();
    // The executor gets its own handle; `sks` is still used below to build inputs.
    executor.setup(&cks, sks.clone());
    // The block count varies below, so use the plain integer client key.
    let cks: crate::integer::ClientKey = cks.into();
    for num_blocks in 1..MAX_NB_CTXT {
        let modulus = cks.parameters().message_modulus().0.pow(num_blocks as u32);

        // Fresh encryptions, with a determinism check on both outputs.
        for _ in 0..nb_tests_smaller {
            let clear_0 = rng.gen::<u64>() % modulus;
            let clear_1 = rng.gen::<u64>() % modulus;
            let ctxt_0 = cks.encrypt_radix(clear_0, num_blocks);
            let (ct_res, result_overflowed) = executor.execute((&ctxt_0, clear_1));
            let (tmp_ct, tmp_o) = executor.execute((&ctxt_0, clear_1));
            assert!(ct_res.block_carries_are_empty());
            assert_eq!(ct_res, tmp_ct, "Failed determinism check, \n\n\n msg0: {clear_0}, msg1: {clear_1}, \n\n\nctxt0: {ctxt_0:?}, \n\n\nclear1: {clear_1:?}\n\n\n");
            assert_eq!(tmp_o, result_overflowed, "Failed determinism check, \n\n\n msg0: {clear_0}, msg1: {clear_1}, \n\n\nctxt0: {ctxt_0:?}, \n\n\nclear1: {clear_1:?}\n\n\n");
            let (expected_result, expected_overflowed) =
                overflowing_add_under_modulus(clear_0, clear_1, modulus);
            let decrypted_result: u64 = cks.decrypt_radix(&ct_res);
            let decrypted_overflowed = cks.decrypt_bool(&result_overflowed);
            assert_eq!(
                decrypted_result, expected_result,
                "Invalid result for add, for ({clear_0} + {clear_1}) % {modulus} ({num_blocks} blocks) \
                expected {expected_result}, got {decrypted_result}"
            );
            assert_eq!(
                decrypted_overflowed,
                expected_overflowed,
                "Invalid overflow flag result for overflowing_add for ({clear_0} + {clear_1}) % {modulus} ({num_blocks} blocks) \
                expected overflow flag {expected_overflowed}, got {decrypted_overflowed}"
            );
            assert_eq!(result_overflowed.0.degree.get(), 1);
            assert_eq!(result_overflowed.0.noise_level(), NoiseLevel::NOMINAL);

            // Inputs that went through an unchecked server-side addition first.
            for _ in 0..nb_tests_smaller {
                let clear_2 = random_non_zero_value(&mut rng, modulus);
                let clear_rhs = random_non_zero_value(&mut rng, modulus);
                // Shadows the outer `ctxt_0` for this iteration only.
                let ctxt_0 = sks.unchecked_scalar_add(&ctxt_0, clear_2);
                let (clear_lhs, _) = overflowing_add_under_modulus(clear_0, clear_2, modulus);
                let d0: u64 = cks.decrypt_radix(&ctxt_0);
                assert_eq!(d0, clear_lhs, "Failed sanity decryption check");
                let (ct_res, result_overflowed) = executor.execute((&ctxt_0, clear_rhs));
                assert!(ct_res.block_carries_are_empty());
                let (expected_result, expected_overflowed) =
                    overflowing_add_under_modulus(clear_lhs, clear_rhs, modulus);
                let decrypted_result: u64 = cks.decrypt_radix(&ct_res);
                let decrypted_overflowed = cks.decrypt_bool(&result_overflowed);
                assert_eq!(
                    decrypted_result, expected_result,
                    "Invalid result for add, for ({clear_lhs} + {clear_rhs}) % {modulus} ({num_blocks} blocks) \
                    expected {expected_result}, got {decrypted_result}"
                );
                assert_eq!(
                    decrypted_overflowed, expected_overflowed,
                    "Invalid overflow flag result for overflowing_add, \
                    for ({clear_lhs} + {clear_rhs}) % {modulus} ({num_blocks} blocks) \n\
                    expected overflow flag {expected_overflowed}, got {decrypted_overflowed}"
                );
                assert_eq!(result_overflowed.0.degree.get(), 1);
                assert_eq!(result_overflowed.0.noise_level(), NoiseLevel::NOMINAL);
            }
        }

        // Trivially encrypted lhs.
        for _ in 0..4 {
            let clear_0 = rng.gen::<u64>() % modulus;
            let clear_1 = rng.gen::<u64>() % modulus;
            let a: RadixCiphertext = sks.create_trivial_radix(clear_0, num_blocks);
            let (encrypted_result, encrypted_overflow) = executor.execute((&a, clear_1));
            let (expected_result, expected_overflowed) =
                overflowing_add_under_modulus(clear_0, clear_1, modulus);
            let decrypted_result: u64 = cks.decrypt_radix(&encrypted_result);
            let decrypted_overflowed = cks.decrypt_bool(&encrypted_overflow);
            assert_eq!(
                decrypted_result, expected_result,
                "Invalid result for add, for ({clear_0} + {clear_1}) % {modulus} ({num_blocks} blocks) \
                expected {expected_result}, got {decrypted_result}"
            );
            assert_eq!(
                decrypted_overflowed, expected_overflowed,
                "Invalid overflow flag result for overflowing_add, \
                for ({clear_0} + {clear_1}) % {modulus} ({num_blocks} blocks) \n\
                expected overflow flag {expected_overflowed}, got {decrypted_overflowed}"
            );
            assert_eq!(encrypted_overflow.0.degree.get(), 1);
            // NOTE(review): presumably the GPU backend does not keep ZERO noise on the
            // trivial overflow flag, hence the cfg — confirm.
            #[cfg(not(feature = "gpu"))]
            assert_eq!(encrypted_overflow.0.noise_level(), NoiseLevel::ZERO);
        }

        // Scalars >= modulus always overflow.
        for _ in 0..2 {
            let clear_0 = rng.gen::<u64>() % modulus;
            let clear_1 = rng.gen_range(modulus..=u64::MAX);
            let a: RadixCiphertext = cks.encrypt_radix(clear_0, num_blocks);
            let (encrypted_result, encrypted_overflow) = executor.execute((&a, clear_1));
            let (expected_result, expected_overflowed) =
                overflowing_add_under_modulus(clear_0, clear_1, modulus);
            let decrypted_result: u64 = cks.decrypt_radix(&encrypted_result);
            let decrypted_overflowed = cks.decrypt_bool(&encrypted_overflow);
            assert_eq!(
                decrypted_result, expected_result,
                "Invalid result for add, for ({clear_0} + {clear_1}) % {modulus} ({num_blocks} blocks) \
                expected {expected_result}, got {decrypted_result}"
            );
            assert_eq!(
                decrypted_overflowed, expected_overflowed,
                "Invalid overflow flag result for overflowing_add, \
                for ({clear_0} + {clear_1}) % {modulus} ({num_blocks} blocks) \n\
                expected overflow flag {expected_overflowed}, got {decrypted_overflowed}"
            );
            assert!(decrypted_overflowed);
            assert_eq!(encrypted_overflow.0.degree.get(), 1);
            assert_eq!(encrypted_overflow.0.noise_level(), NoiseLevel::ZERO);
        }
    }
}
/// Checks the default scalar subtraction on radix ciphertexts of 1..MAX_NB_CTXT blocks.
///
/// For each block count, a random scalar is subtracted from a fresh encryption. The same
/// scalar is then repeatedly subtracted from the running result. With deterministic PBS
/// enabled, each chained step is computed twice and must match. Every result must have empty
/// carries and decrypt to the wrapping clear subtraction reduced modulo the radix modulus.
pub(crate) fn default_scalar_sub_test<P, T>(param: P, mut executor: T)
where
    P: Into<TestParameters>,
    T: for<'a> FunctionExecutor<(&'a RadixCiphertext, u64), RadixCiphertext>,
{
    let param = param.into();
    let nb_tests_smaller = nb_tests_smaller_for_params(param);
    let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
    let cks = RadixClientKey::from((cks, NB_CTXT));
    sks.set_deterministic_pbs_execution(true);
    let sks = Arc::new(sks);
    let mut rng = rand::thread_rng();
    executor.setup(&cks, sks);
    // The block count varies below, so use the plain integer client key.
    let cks: crate::integer::ClientKey = cks.into();
    let mut clear;
    for num_blocks in 1..MAX_NB_CTXT {
        let modulus = cks.parameters().message_modulus().0.pow(num_blocks as u32);
        for _ in 0..nb_tests_smaller {
            let clear_0 = rng.gen::<u64>() % modulus;
            let clear_1 = rng.gen::<u64>() % modulus;
            let ctxt_0 = cks.encrypt_radix(clear_0, num_blocks);
            let mut ct_res = executor.execute((&ctxt_0, clear_1));
            assert!(ct_res.block_carries_are_empty());
            clear = (clear_0.wrapping_sub(clear_1)) % modulus;
            // Check the first subtraction too, as the scalar-add test does, rather than
            // relying on the chained results below to catch a wrong first step.
            let dec_res: u64 = cks.decrypt_radix(&ct_res);
            assert_eq!(
                clear, dec_res,
                "invalid result for ({clear_0} - {clear_1}) % {modulus} (num_blocks: {num_blocks})"
            );
            for _ in 0..nb_tests_smaller {
                let tmp = executor.execute((&ct_res, clear_1));
                ct_res = executor.execute((&ct_res, clear_1));
                assert!(ct_res.block_carries_are_empty());
                assert_eq!(ct_res, tmp);
                clear = (clear.wrapping_sub(clear_1)) % modulus;
                let dec_res: u64 = cks.decrypt_radix(&ct_res);
                assert_eq!(clear, dec_res);
            }
        }
    }
}
/// Checks the overflowing scalar subtraction on radix ciphertexts of 1..MAX_NB_CTXT blocks.
///
/// For every block count this covers four scenarios:
/// - fresh encryptions minus a random in-range scalar. Each is computed twice under
///   deterministic PBS, and the result and overflow flag must be identical across runs;
/// - lhs ciphertexts first modified with `unchecked_scalar_add`, so the input comes from a
///   server-side operation rather than straight from encryption;
/// - trivially encrypted lhs values (`create_trivial_radix`);
/// - scalars at or above the radix modulus, which must always report an overflow.
///
/// Results must decrypt to `overflowing_sub_under_modulus`. The overflow flag must have
/// degree 1 and the noise level asserted for each scenario.
pub(crate) fn default_overflowing_scalar_sub_test<P, T>(param: P, mut executor: T)
where
    P: Into<TestParameters>,
    T: for<'a> FunctionExecutor<(&'a RadixCiphertext, u64), (RadixCiphertext, BooleanBlock)>,
{
    let param = param.into();
    let nb_tests_smaller = nb_tests_smaller_for_params(param);
    let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
    let cks = RadixClientKey::from((cks, NB_CTXT));
    sks.set_deterministic_pbs_execution(true);
    let sks = Arc::new(sks);
    let mut rng = rand::thread_rng();
    // The executor gets its own handle; `sks` is still used below to build inputs.
    executor.setup(&cks, sks.clone());
    // The block count varies below, so use the plain integer client key.
    let cks: crate::integer::ClientKey = cks.into();
    for num_blocks in 1..MAX_NB_CTXT {
        let modulus = cks.parameters().message_modulus().0.pow(num_blocks as u32);

        // Fresh encryptions, with a determinism check on both outputs.
        for _ in 0..nb_tests_smaller {
            let clear_0 = rng.gen::<u64>() % modulus;
            let clear_1 = rng.gen::<u64>() % modulus;
            let ctxt_0 = cks.encrypt_radix(clear_0, num_blocks);
            let (ct_res, result_overflowed) = executor.execute((&ctxt_0, clear_1));
            let (tmp_ct, tmp_o) = executor.execute((&ctxt_0, clear_1));
            assert!(ct_res.block_carries_are_empty());
            assert_eq!(ct_res, tmp_ct, "Failed determinism check, \n\n\n msg0: {clear_0}, msg1: {clear_1}, \n\n\nctxt0: {ctxt_0:?}, \n\n\nclear1: {clear_1:?}\n\n\n");
            assert_eq!(tmp_o, result_overflowed, "Failed determinism check, \n\n\n msg0: {clear_0}, msg1: {clear_1}, \n\n\nctxt0: {ctxt_0:?}, \n\n\nclear1: {clear_1:?}\n\n\n");
            let (expected_result, expected_overflowed) =
                overflowing_sub_under_modulus(clear_0, clear_1, modulus);
            let decrypted_result: u64 = cks.decrypt_radix(&ct_res);
            let decrypted_overflowed = cks.decrypt_bool(&result_overflowed);
            assert_eq!(
                decrypted_result, expected_result,
                "Invalid result for sub, for ({clear_0} - {clear_1}) % {modulus} ({num_blocks} blocks) \
                expected {expected_result}, got {decrypted_result}"
            );
            assert_eq!(
                decrypted_overflowed,
                expected_overflowed,
                "Invalid overflow flag result for overflowing_sub for ({clear_0} - {clear_1}) % {modulus} ({num_blocks} blocks) \
                expected overflow flag {expected_overflowed}, got {decrypted_overflowed}"
            );
            assert_eq!(result_overflowed.0.degree.get(), 1);
            assert_eq!(result_overflowed.0.noise_level(), NoiseLevel::NOMINAL);

            // Inputs that went through an unchecked server-side addition first.
            for _ in 0..nb_tests_smaller {
                let clear_2 = random_non_zero_value(&mut rng, modulus);
                let clear_rhs = random_non_zero_value(&mut rng, modulus);
                // Shadows the outer `ctxt_0` for this iteration only.
                let ctxt_0 = sks.unchecked_scalar_add(&ctxt_0, clear_2);
                let (clear_lhs, _) = overflowing_add_under_modulus(clear_0, clear_2, modulus);
                let d0: u64 = cks.decrypt_radix(&ctxt_0);
                assert_eq!(d0, clear_lhs, "Failed sanity decryption check");
                let (ct_res, result_overflowed) = executor.execute((&ctxt_0, clear_rhs));
                assert!(ct_res.block_carries_are_empty());
                let (expected_result, expected_overflowed) =
                    overflowing_sub_under_modulus(clear_lhs, clear_rhs, modulus);
                let decrypted_result: u64 = cks.decrypt_radix(&ct_res);
                let decrypted_overflowed = cks.decrypt_bool(&result_overflowed);
                assert_eq!(
                    decrypted_result, expected_result,
                    "Invalid result for sub, for ({clear_lhs} - {clear_rhs}) % {modulus} ({num_blocks} blocks) \
                    expected {expected_result}, got {decrypted_result}"
                );
                assert_eq!(
                    decrypted_overflowed,
                    expected_overflowed,
                    "Invalid overflow flag result for overflowing_sub, for ({clear_lhs} - {clear_rhs}) % {modulus} ({num_blocks} blocks) \
                    expected overflow flag {expected_overflowed}, got {decrypted_overflowed}"
                );
                assert_eq!(result_overflowed.0.degree.get(), 1);
                assert_eq!(result_overflowed.0.noise_level(), NoiseLevel::NOMINAL);
            }
        }

        // Trivially encrypted lhs.
        for _ in 0..4 {
            let clear_0 = rng.gen::<u64>() % modulus;
            let clear_1 = rng.gen::<u64>() % modulus;
            let a: RadixCiphertext = sks.create_trivial_radix(clear_0, num_blocks);
            let (encrypted_result, encrypted_overflow) = executor.execute((&a, clear_1));
            let (expected_result, expected_overflowed) =
                overflowing_sub_under_modulus(clear_0, clear_1, modulus);
            let decrypted_result: u64 = cks.decrypt_radix(&encrypted_result);
            let decrypted_overflowed = cks.decrypt_bool(&encrypted_overflow);
            assert_eq!(
                decrypted_result, expected_result,
                "Invalid result for sub, for ({clear_0} - {clear_1}) % {modulus} ({num_blocks} blocks) \
                expected {expected_result}, got {decrypted_result}"
            );
            assert_eq!(
                decrypted_overflowed,
                expected_overflowed,
                "Invalid overflow flag result for overflowing_sub, for ({clear_0} - {clear_1}) % {modulus} ({num_blocks} blocks) \
                expected overflow flag {expected_overflowed}, got {decrypted_overflowed}"
            );
            assert_eq!(encrypted_overflow.0.degree.get(), 1);
            assert_eq!(encrypted_overflow.0.noise_level(), NoiseLevel::ZERO);
        }

        // Scalars >= modulus always underflow, so the flag must be set.
        for _ in 0..2 {
            let clear_0 = rng.gen::<u64>() % modulus;
            let clear_1 = rng.gen_range(modulus..=u64::MAX);
            let a: RadixCiphertext = cks.encrypt_radix(clear_0, num_blocks);
            let (encrypted_result, encrypted_overflow) = executor.execute((&a, clear_1));
            let (expected_result, expected_overflowed) =
                overflowing_sub_under_modulus(clear_0, clear_1, modulus);
            let decrypted_result: u64 = cks.decrypt_radix(&encrypted_result);
            let decrypted_overflowed = cks.decrypt_bool(&encrypted_overflow);
            assert_eq!(
                decrypted_result, expected_result,
                "Invalid result for sub, for ({clear_0} - {clear_1}) % {modulus} ({num_blocks} blocks) \
                expected {expected_result}, got {decrypted_result}"
            );
            assert_eq!(
                decrypted_overflowed,
                expected_overflowed,
                "Invalid overflow flag result for overflowing_sub, for ({clear_0} - {clear_1}) % {modulus} ({num_blocks} blocks) \
                expected overflow flag {expected_overflowed}, got {decrypted_overflowed}"
            );
            assert!(decrypted_overflowed);
            assert_eq!(encrypted_overflow.0.degree.get(), 1);
            assert_eq!(encrypted_overflow.0.noise_level(), NoiseLevel::ZERO);
        }
    }
}
/// Checks the default scalar multiplication on `NB_CTXT`-block radix ciphertexts.
///
/// With deterministic PBS enabled, each product is computed twice and both results must be
/// identical. They must also have empty carries and decrypt to the clear product modulo the
/// radix modulus.
pub(crate) fn default_scalar_mul_test<P, T>(param: P, mut executor: T)
where
    P: Into<TestParameters>,
    T: for<'a> FunctionExecutor<(&'a RadixCiphertext, u64), RadixCiphertext>,
{
    let param = param.into();
    let nb_tests_smaller = nb_tests_smaller_for_params(param);
    let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
    let cks = RadixClientKey::from((cks, NB_CTXT));
    sks.set_deterministic_pbs_execution(true);
    let sks = Arc::new(sks);
    let mut rng = rand::thread_rng();
    let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32);
    executor.setup(&cks, sks);
    for _ in 0..nb_tests_smaller {
        let clear = rng.gen::<u64>() % modulus;
        let scalar = rng.gen::<u64>() % modulus;
        let ct = cks.encrypt(clear);
        let ct_res = executor.execute((&ct, scalar));
        let tmp = executor.execute((&ct, scalar));
        assert!(ct_res.block_carries_are_empty());
        assert_eq!(ct_res, tmp);
        let dec_res: u64 = cks.decrypt(&ct_res);
        // `clear * scalar` can exceed u64 for large moduli and would panic in debug builds.
        // Wrapping first matches the sibling block-mul test. This is exact assuming the radix
        // modulus is a power of two (so it divides 2^64) — TODO confirm for all parameter sets.
        assert_eq!(clear.wrapping_mul(scalar) % modulus, dec_res);
    }
}
/// Checks multiplying a radix ciphertext by one encrypted block placed at a given block index.
///
/// Multiplying by block `b` at index `i` is expected to equal multiplying by
/// `b * message_modulus^i`. The same (block, index) pair is applied repeatedly to a running
/// product. Under deterministic PBS, each step is computed twice and the two results must
/// match. Each result must also:
/// - have empty carries;
/// - keep every block's noise level at most nominal;
/// - decrypt to the clear product modulo the radix modulus.
pub(crate) fn default_default_block_mul_test<P, T>(param: P, mut executor: T)
where
    P: Into<TestParameters>,
    T: for<'a> FunctionExecutor<
        (&'a RadixCiphertext, &'a crate::shortint::Ciphertext, usize),
        RadixCiphertext,
    >,
{
    let param = param.into();
    let rounds = nb_tests_smaller_for_params(param);
    let (integer_cks, mut integer_sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
    let radix_cks = RadixClientKey::from((integer_cks, NB_CTXT));
    integer_sks.set_deterministic_pbs_execution(true);
    let mut rng = rand::thread_rng();
    let block_modulus = radix_cks.parameters().message_modulus().0;
    let modulus = block_modulus.pow(NB_CTXT as u32);
    executor.setup(&radix_cks, Arc::new(integer_sks));

    for _ in 0..rounds {
        let radix_value = rng.gen::<u64>() % modulus;
        let block_value = rng.gen::<u64>() % block_modulus;
        let radix_ct = radix_cks.encrypt(radix_value);
        let block_ct = radix_cks.encrypt_one_block(block_value);
        let block_index = rng.gen_range(0..=(NB_CTXT - 1) as u32);
        // Weight of the chosen block position in the radix decomposition.
        let weight = block_modulus.pow(block_index);
        let block_index = block_index as usize;

        let mut product_ct = radix_ct;
        let mut expected = radix_value;
        for _ in 0..rounds {
            let first_run = executor.execute((&product_ct, &block_ct, block_index));
            product_ct = executor.execute((&product_ct, &block_ct, block_index));
            assert!(product_ct.block_carries_are_empty());
            assert!(product_ct
                .blocks
                .iter()
                .all(|block| block.noise_level() <= NoiseLevel::NOMINAL));
            assert_eq!(product_ct, first_run);
            expected = expected.wrapping_mul(block_value.wrapping_mul(weight)) % modulus;

            let decrypted: u64 = radix_cks.decrypt(&product_ct);
            assert_eq!(expected, decrypted);
        }
    }
}
/// Non-regression test: multiplies a full 128-bit radix ciphertext by a random `u64` scalar.
///
/// The radix ciphertext gets enough blocks to hold 128 bits. The decrypted result must equal
/// the wrapping `u128` product.
pub(crate) fn default_scalar_mul_u128_fix_non_reg_test<P, T>(param: P, mut executor: T)
where
    P: Into<TestParameters>,
    T: for<'a> FunctionExecutor<(&'a RadixCiphertext, u64), RadixCiphertext>,
{
    let (integer_cks, mut integer_sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
    let bits_per_block = (integer_cks.parameters().message_modulus().0 as f64)
        .log2()
        .ceil();
    let num_blocks = (128f64 / bits_per_block).ceil() as usize;
    let radix_cks = RadixClientKey::from((integer_cks, num_blocks));
    integer_sks.set_deterministic_pbs_execution(true);
    executor.setup(&radix_cks, Arc::new(integer_sks));

    let mut rng = rand::thread_rng();
    let clear = rng.gen::<u128>();
    let scalar = rng.gen::<u64>();
    let ct = radix_cks.encrypt(clear);
    let ct_res = executor.execute((&ct, scalar));
    let dec_res: u128 = radix_cks.decrypt(&ct_res);
    assert_eq!(
        clear.wrapping_mul(u128::from(scalar)),
        dec_res,
        "Invalid result {clear} * {scalar}"
    );
}
/// Checks the default scalar bitwise AND on `NB_CTXT`-block radix ciphertexts.
///
/// Each round first smoke-tests AND with the scalar `1`. It then chains ANDs with random
/// scalars. Each chained step is computed twice under deterministic PBS and must match, have
/// empty carries, and decrypt to the clear AND.
pub(crate) fn default_scalar_bitand_test<P, T>(param: P, mut executor: T)
where
    P: Into<TestParameters>,
    T: for<'a> FunctionExecutor<(&'a RadixCiphertext, u64), RadixCiphertext>,
{
    let param = param.into();
    let rounds = nb_tests_smaller_for_params(param);
    let (integer_cks, mut integer_sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
    let radix_cks = RadixClientKey::from((integer_cks, NB_CTXT));
    integer_sks.set_deterministic_pbs_execution(true);
    let mut rng = rand::thread_rng();
    let modulus = radix_cks.parameters().message_modulus().0.pow(NB_CTXT as u32);
    executor.setup(&radix_cks, Arc::new(integer_sks));

    for _ in 0..rounds {
        let value = rng.gen::<u64>() % modulus;
        let mask = rng.gen::<u64>() % modulus;
        let ct = radix_cks.encrypt(value);

        let with_one = executor.execute((&ct, 1));
        let decrypted: u64 = radix_cks.decrypt(&with_one);
        assert_eq!(value & 1, decrypted);

        let mut running = executor.execute((&ct, mask));
        assert!(running.block_carries_are_empty());
        let mut expected = value & mask;

        for _ in 0..rounds {
            let next_mask = rng.gen::<u64>() % modulus;
            let first_run = executor.execute((&running, next_mask));
            running = executor.execute((&running, next_mask));
            assert!(running.block_carries_are_empty());
            assert_eq!(running, first_run);
            expected &= next_mask;

            let decrypted: u64 = radix_cks.decrypt(&running);
            assert_eq!(expected, decrypted);
        }
    }
}
/// Checks the default scalar bitwise OR on `NB_CTXT`-block radix ciphertexts.
///
/// Each round first smoke-tests OR with the scalar `1`. It then chains ORs with random
/// scalars. Each chained step is computed twice under deterministic PBS and must match, have
/// empty carries, and decrypt to the clear OR.
pub(crate) fn default_scalar_bitor_test<P, T>(param: P, mut executor: T)
where
    P: Into<TestParameters>,
    T: for<'a> FunctionExecutor<(&'a RadixCiphertext, u64), RadixCiphertext>,
{
    let param = param.into();
    let rounds = nb_tests_smaller_for_params(param);
    let (integer_cks, mut integer_sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
    let radix_cks = RadixClientKey::from((integer_cks, NB_CTXT));
    integer_sks.set_deterministic_pbs_execution(true);
    let mut rng = rand::thread_rng();
    let modulus = radix_cks.parameters().message_modulus().0.pow(NB_CTXT as u32);
    executor.setup(&radix_cks, Arc::new(integer_sks));

    for _ in 0..rounds {
        let value = rng.gen::<u64>() % modulus;
        let bits = rng.gen::<u64>() % modulus;
        let ct = radix_cks.encrypt(value);

        let with_one = executor.execute((&ct, 1));
        let decrypted: u64 = radix_cks.decrypt(&with_one);
        assert_eq!(value | 1, decrypted);

        let mut running = executor.execute((&ct, bits));
        assert!(running.block_carries_are_empty());
        let mut expected = (value | bits) % modulus;

        for _ in 0..rounds {
            let next_bits = rng.gen::<u64>() % modulus;
            let first_run = executor.execute((&running, next_bits));
            running = executor.execute((&running, next_bits));
            assert!(running.block_carries_are_empty());
            assert_eq!(running, first_run);
            expected = (expected | next_bits) % modulus;

            let decrypted: u64 = radix_cks.decrypt(&running);
            assert_eq!(expected, decrypted);
        }
    }
}
/// Checks the default scalar bitwise XOR on `NB_CTXT`-block radix ciphertexts.
///
/// Each round first smoke-tests XOR with the scalar `1`. It then chains XORs with random
/// scalars. Each chained step is computed twice under deterministic PBS and must match, have
/// empty carries, and decrypt to the clear XOR.
pub(crate) fn default_scalar_bitxor_test<P, T>(param: P, mut executor: T)
where
    P: Into<TestParameters>,
    T: for<'a> FunctionExecutor<(&'a RadixCiphertext, u64), RadixCiphertext>,
{
    let param = param.into();
    let rounds = nb_tests_smaller_for_params(param);
    let (integer_cks, mut integer_sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
    let radix_cks = RadixClientKey::from((integer_cks, NB_CTXT));
    integer_sks.set_deterministic_pbs_execution(true);
    let mut rng = rand::thread_rng();
    let modulus = radix_cks.parameters().message_modulus().0.pow(NB_CTXT as u32);
    executor.setup(&radix_cks, Arc::new(integer_sks));

    for _ in 0..rounds {
        let value = rng.gen::<u64>() % modulus;
        let bits = rng.gen::<u64>() % modulus;
        let ct = radix_cks.encrypt(value);

        let with_one = executor.execute((&ct, 1));
        let decrypted: u64 = radix_cks.decrypt(&with_one);
        assert_eq!(value ^ 1, decrypted);

        let mut running = executor.execute((&ct, bits));
        assert!(running.block_carries_are_empty());
        let mut expected = (value ^ bits) % modulus;

        for _ in 0..rounds {
            let next_bits = rng.gen::<u64>() % modulus;
            let first_run = executor.execute((&running, next_bits));
            running = executor.execute((&running, next_bits));
            assert!(running.block_carries_are_empty());
            assert_eq!(running, first_run);
            expected = (expected ^ next_bits) % modulus;

            let decrypted: u64 = radix_cks.decrypt(&running);
            assert_eq!(expected, decrypted);
        }
    }
}
/// Checks the default scalar left shift on `NB_CTXT`-block radix ciphertexts.
///
/// It checks three cases:
/// - random in-range shift amounts;
/// - amounts of at least `nb_bits`, which are expected to behave like `amount % nb_bits`;
/// - every shift smaller than one block's bit width, tested exhaustively.
///
/// Every shift is run twice under deterministic PBS and the two results must match. Results
/// must also have empty carries and decrypt to the clear shift modulo the radix modulus.
pub(crate) fn default_scalar_left_shift_test<P, T>(param: P, mut executor: T)
where
    P: Into<TestParameters>,
    T: for<'a> FunctionExecutor<(&'a RadixCiphertext, u64), RadixCiphertext>,
{
    let param = param.into();
    let nb_tests = nb_tests_for_params(param);
    let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
    let cks = RadixClientKey::from((cks, NB_CTXT));
    sks.set_deterministic_pbs_execution(true);
    let sks = Arc::new(sks);
    let mut rng = rand::thread_rng();
    let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32);
    let nb_bits = modulus.ilog2();
    executor.setup(&cks, sks);
    for _ in 0..nb_tests {
        let clear = rng.gen::<u64>() % modulus;
        let scalar = rng.gen::<u32>();
        let ct = cks.encrypt(clear);
        {
            // In-range shift amount.
            let scalar = scalar % nb_bits;
            let ct_res = executor.execute((&ct, scalar as u64));
            let tmp = executor.execute((&ct, scalar as u64));
            assert!(ct_res.block_carries_are_empty());
            assert_eq!(ct_res, tmp);
            let dec_res: u64 = cks.decrypt(&ct_res);
            assert_eq!(clear.checked_shl(scalar).unwrap_or(0) % modulus, dec_res);
        }
        {
            // Shift amount >= nb_bits: expected to wrap modulo nb_bits.
            let scalar = scalar.saturating_add(nb_bits);
            let ct_res = executor.execute((&ct, scalar as u64));
            let tmp = executor.execute((&ct, scalar as u64));
            assert!(ct_res.block_carries_are_empty());
            assert_eq!(ct_res, tmp);
            let dec_res: u64 = cks.decrypt(&ct_res);
            assert_eq!(clear.wrapping_shl(scalar % nb_bits) % modulus, dec_res);
        }
    }
    let clear = rng.gen::<u64>() % modulus;
    let ct = cks.encrypt(clear);
    // Exhaustive check of the sub-block shifts. It applies the same carry and determinism
    // checks as above, mirroring the right-shift test.
    let nb_bits_in_block = cks.parameters().message_modulus().0.ilog2();
    for scalar in 0..nb_bits_in_block {
        let ct_res = executor.execute((&ct, scalar as u64));
        let tmp = executor.execute((&ct, scalar as u64));
        assert!(ct_res.block_carries_are_empty());
        assert_eq!(ct_res, tmp);
        let dec_res: u64 = cks.decrypt(&ct_res);
        assert_eq!(clear.wrapping_shl(scalar % nb_bits) % modulus, dec_res);
    }
}
/// Checks the default scalar right shift on `NB_CTXT`-block radix ciphertexts.
///
/// It checks three cases:
/// - random in-range shift amounts;
/// - amounts of at least `nb_bits`, which are expected to behave like `amount % nb_bits`;
/// - every shift smaller than one block's bit width, tested exhaustively.
///
/// Every shift is run twice under deterministic PBS and the two results must match. Results
/// must also have empty carries and decrypt to the clear shift modulo the radix modulus.
pub(crate) fn default_scalar_right_shift_test<P, T>(param: P, mut executor: T)
where
    P: Into<TestParameters>,
    T: for<'a> FunctionExecutor<(&'a RadixCiphertext, u64), RadixCiphertext>,
{
    let param = param.into();
    let rounds = nb_tests_smaller_for_params(param);
    let (integer_cks, mut integer_sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
    let radix_cks = RadixClientKey::from((integer_cks, NB_CTXT));
    integer_sks.set_deterministic_pbs_execution(true);
    let mut rng = rand::thread_rng();
    executor.setup(&radix_cks, Arc::new(integer_sks));
    let modulus = radix_cks.parameters().message_modulus().0.pow(NB_CTXT as u32);
    let nb_bits = modulus.ilog2();

    // Shifts `ct` twice by `amount` and checks determinism and carries. Returns the decrypted
    // value of the result.
    let mut shift_and_decrypt = |ct: &RadixCiphertext, amount: u32| -> u64 {
        let ct_res = executor.execute((ct, amount as u64));
        let tmp = executor.execute((ct, amount as u64));
        assert!(ct_res.block_carries_are_empty());
        assert_eq!(ct_res, tmp);
        radix_cks.decrypt(&ct_res)
    };

    for _ in 0..rounds {
        let clear = rng.gen::<u64>() % modulus;
        let random_amount = rng.gen::<u32>();
        let ct = radix_cks.encrypt(clear);

        let in_range = random_amount % nb_bits;
        assert_eq!(
            clear.wrapping_shr(in_range) % modulus,
            shift_and_decrypt(&ct, in_range)
        );

        let out_of_range = random_amount.saturating_add(nb_bits);
        assert_eq!(
            clear.wrapping_shr(out_of_range % nb_bits) % modulus,
            shift_and_decrypt(&ct, out_of_range)
        );
    }

    let clear = rng.gen::<u64>() % modulus;
    let ct = radix_cks.encrypt(clear);
    let nb_bits_in_block = radix_cks.parameters().message_modulus().0.ilog2();
    for amount in 0..nb_bits_in_block {
        assert_eq!(
            clear.wrapping_shr(amount) % modulus,
            shift_and_decrypt(&ct, amount)
        );
    }
}
/// Tests a full-propagation operation (wrapped by `executor`) on radix ciphertexts.
///
/// Several independent scenarios build a ciphertext whose blocks carry dirty carries, raised
/// degrees or raised noise levels. Each then calls `executor.execute(&mut ct)` and checks:
/// - the decrypted value is the expected sum, reduced modulo the radix modulus where it applies;
/// - every block has an empty carry;
/// - every block's degree is `message_modulus - 1` (at most that in one scenario);
/// - every block's noise level is at most `NoiseLevel::NOMINAL`.
///
/// The client key is built with at least 4 blocks, because several scenarios build 4-block
/// ciphertexts by hand and one of them indexes block 2.
///
/// # Panics
///
/// Panics (fails the test) if any of the checks above does not hold.
pub(crate) fn full_propagate_test<P, T>(param: P, mut executor: T)
where
    P: Into<TestParameters>,
    T: for<'a> FunctionExecutor<&'a mut RadixCiphertext, ()>,
{
    let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
    let sks = Arc::new(sks);
    // The hand-built scenarios below need at least 4 blocks.
    let nb_ctxt = 4.max(NB_CTXT);
    let cks = RadixClientKey::from((cks, nb_ctxt));
    let modulus = cks.parameters().message_modulus().0.pow(nb_ctxt as u32);
    executor.setup(&cks, sks.clone());

    // Shortint key, used to decrypt each block's message *and* carry parts.
    let shortint_cks = &cks.as_ref().key;

    let block_msg_mod = cks.parameters().message_modulus().0;
    let block_carry_mod = cks.parameters().carry_modulus().0;
    let block_total_mod = block_carry_mod * block_msg_mod;
    let clear_max_value = modulus - 1;

    // Scenario 1: `max + msg`. Every block of `max` holds `block_msg_mod - 1`, so the carry
    // produced by the first block has to ripple through all the other blocks.
    for msg in 1..block_msg_mod {
        // The first block would hold `block_msg_mod - 1 + msg`; skip values that do not fit
        // in the message + carry space of a block.
        if (block_msg_mod - 1) + msg >= block_total_mod {
            continue;
        }
        let max_value = cks.encrypt(clear_max_value);
        let rhs = cks.encrypt(msg);
        let mut ct = sks.unchecked_add(&max_value, &rhs);

        // Before propagation, only the first block has a non-empty carry.
        let first_block = shortint_cks.decrypt_message_and_carry(&ct.blocks[0]);
        let first_block_msg = first_block % block_msg_mod;
        let first_block_carry = first_block / block_msg_mod;
        assert_eq!(first_block_msg, (block_msg_mod - 1 + msg) % block_msg_mod);
        // The carry of the first block is the quotient of its content by the message modulus.
        assert_eq!(first_block_carry, (block_msg_mod - 1 + msg) / block_msg_mod);
        for b in &ct.blocks[1..] {
            let block = shortint_cks.decrypt_message_and_carry(b);
            let block_msg = block % block_msg_mod;
            let block_carry = block / block_msg_mod;
            assert_eq!(block_msg, block_msg_mod - 1);
            assert_eq!(block_carry, 0);
        }

        executor.execute(&mut ct);

        let decrypted_result: u64 = cks.decrypt(&ct);
        let expected_result = clear_max_value.wrapping_add(msg) % modulus;
        assert_eq!(
            decrypted_result, expected_result,
            "Invalid full propagation result, gave ct = {clear_max_value} + {msg}, \
            after propagation expected {expected_result}, got {decrypted_result}"
        );
        assert!(
            ct.blocks
                .iter()
                .all(|b| b.degree.get() == block_msg_mod - 1),
            "Invalid degree after propagation"
        );
        assert!(
            ct.blocks
                .iter()
                .all(|b| b.noise_level() <= NoiseLevel::NOMINAL),
            "Invalid noise_level after propagation"
        );
        // The sum wrapped around: only the first block may be non-zero, and no carry remains.
        assert_eq!(
            shortint_cks.decrypt_message_and_carry(&ct.blocks[0]),
            (block_msg_mod - 1 + msg) % block_msg_mod
        );
        for b in &ct.blocks[1..] {
            assert_eq!(shortint_cks.decrypt_message_and_carry(b), 0);
        }
    }

    // Scenario 2: add `max` repeatedly while the server key allows it, then add a scalar so
    // that the first block's degree grows by `block_msg_mod - 1` on top of that bound.
    if block_carry_mod >= block_msg_mod {
        let mut expected_result = clear_max_value;
        let msg = cks.encrypt(clear_max_value);
        let mut ct = cks.encrypt(clear_max_value);
        while sks.is_add_possible(&ct, &msg).is_ok() {
            sks.unchecked_add_assign(&mut ct, &msg);
            expected_result = expected_result.wrapping_add(clear_max_value) % modulus;
        }
        // (total - 1) - (carry - 1): presumably the largest degree that still leaves room to
        // absorb an incoming carry of up to `carry_mod - 1`.
        let max_degree_that_can_absorb_carry = (block_total_mod - 1) - (block_carry_mod - 1);
        assert!(ct
            .blocks
            .iter()
            .all(|b| b.degree.get() <= max_degree_that_can_absorb_carry));
        sks.is_scalar_add_possible(&ct, block_msg_mod - 1).unwrap();
        sks.unchecked_scalar_add_assign(&mut ct, block_msg_mod - 1);
        assert_eq!(
            ct.blocks[0].degree.get(),
            max_degree_that_can_absorb_carry + (block_msg_mod - 1)
        );
        expected_result = expected_result.wrapping_add(block_msg_mod - 1) % modulus;

        executor.execute(&mut ct);

        let decrypted_result: u64 = cks.decrypt(&ct);
        assert_eq!(
            decrypted_result, expected_result,
            "Invalid full propagation result, expected {expected_result}, got {decrypted_result}"
        );
        assert!(
            ct.blocks
                .iter()
                .all(|b| b.degree.get() == block_msg_mod - 1),
            "Invalid degree after propagation"
        );
        assert!(
            ct.blocks
                .iter()
                .all(|b| b.noise_level() <= NoiseLevel::NOMINAL),
            "Invalid noise_level after propagation"
        );
        let expected_block_iter = BlockDecomposer::new(expected_result, block_msg_mod.ilog2())
            .iter_as::<u64>()
            .take(cks.num_blocks());
        for (block, expected_msg) in ct.blocks.iter().zip(expected_block_iter) {
            let block = shortint_cks.decrypt_message_and_carry(block);
            let block_msg = block % block_msg_mod;
            let block_carry = block / block_msg_mod;
            assert_eq!(block_msg, expected_msg);
            assert_eq!(block_carry, 0);
        }
    }

    // Scenario 3: the carry produced by the first block ripples up to a zeroed "absorber"
    // block, which must absorb it and stop the propagation.
    {
        assert!(cks.num_blocks() >= 4);
        assert!(block_msg_mod.is_power_of_two());
        let absorber_block_index = 2;
        let mut ct = cks.encrypt(clear_max_value);
        ct.blocks[absorber_block_index] = cks.encrypt_one_block(0);
        let block_mask = block_msg_mod - 1;
        let num_bits_in_msg = block_msg_mod.ilog2();
        let absorber_block_mask = block_mask << (absorber_block_index as u32 * num_bits_in_msg);
        let mask = u64::MAX ^ absorber_block_mask;
        // Clear value of `ct` once the absorber block has been zeroed.
        let initial_value = clear_max_value & mask;
        let to_add = cks.encrypt(block_msg_mod - 1);
        sks.unchecked_add_assign(&mut ct, &to_add);
        let expected_result = initial_value.wrapping_add(block_msg_mod - 1) % modulus;

        // Before propagation: only the first block holds more than a message.
        let mut expected_blocks = vec![block_msg_mod - 1; cks.num_blocks()];
        expected_blocks[0] += block_msg_mod - 1;
        expected_blocks[absorber_block_index] = 0;
        for (block, expected_block) in ct.blocks.iter().zip(expected_blocks) {
            let block = shortint_cks.decrypt_message_and_carry(block);
            let block_msg = block % block_msg_mod;
            let block_carry = block / block_msg_mod;
            let expected_msg = expected_block % block_msg_mod;
            let expected_carry = expected_block / block_msg_mod;
            assert_eq!(block_msg, expected_msg);
            assert_eq!(block_carry, expected_carry);
        }

        executor.execute(&mut ct);

        let decrypted_result: u64 = cks.decrypt(&ct);
        assert_eq!(
            decrypted_result, expected_result,
            "Invalid full propagation result, expected {expected_result}, got {decrypted_result}"
        );
        assert!(
            ct.blocks
                .iter()
                .all(|b| b.degree.get() == block_msg_mod - 1),
            "Invalid degree after propagation"
        );
        assert!(
            ct.blocks
                .iter()
                .all(|b| b.noise_level() <= NoiseLevel::NOMINAL),
            "Invalid noise_level after propagation"
        );
        // Build the expected value independently of the modular arithmetic above:
        // - blocks above the absorber are kept;
        // - block 0 keeps the message part of its doubled value;
        // - block 1 wraps to 0;
        // - the absorber receives the carry.
        let mut expected_built_by_hand =
            initial_value & (u64::MAX << ((absorber_block_index + 1) as u32 * num_bits_in_msg));
        expected_built_by_hand |= (2 * (block_msg_mod - 1)) % block_msg_mod;
        expected_built_by_hand |= 1 << (absorber_block_index as u32 * num_bits_in_msg);
        assert_eq!(expected_result, expected_built_by_hand);
        let expected_block_iter =
            BlockDecomposer::new(expected_built_by_hand, block_msg_mod.ilog2())
                .iter_as::<u64>()
                .take(cks.num_blocks());
        for (block, expected_msg) in ct.blocks.iter().zip(expected_block_iter) {
            let block = shortint_cks.decrypt_message_and_carry(block);
            let block_msg = block % block_msg_mod;
            let block_carry = block / block_msg_mod;
            assert_eq!(block_msg, expected_msg);
            assert_eq!(block_carry, 0);
        }
    }

    // Scenario 4: clean blocks whose noise level is artificially raised must come back with a
    // nominal noise level and an unchanged degree.
    {
        let block_max_value = block_msg_mod - 1;
        let blocks = vec![
            cks.encrypt_one_block(block_max_value),
            cks.encrypt_one_block(block_max_value),
            cks.encrypt_one_block(block_max_value),
            cks.encrypt_one_block(block_max_value),
        ];
        let mut ct = RadixCiphertext::from(blocks);
        for block in &ct.blocks {
            assert_eq!(
                block.noise_degree(),
                CiphertextNoiseDegree::new(NoiseLevel::NOMINAL, Degree::new(block_max_value))
            );
        }
        for block in &mut ct.blocks {
            block.set_noise_level(NoiseLevel::NOMINAL * 2, sks.key.max_noise_level);
        }

        executor.execute(&mut ct);

        let clean_noise_degree =
            CiphertextNoiseDegree::new(NoiseLevel::NOMINAL, Degree::new(block_max_value));
        for block in &ct.blocks {
            assert_eq!(block.noise_degree(), clean_noise_degree);
        }
    }

    // Scenario 5: boolean blocks (degree 1) and full blocks, summed as many times as both the
    // degree and the noise budget allow.
    {
        let block_max_value = block_msg_mod - 1;
        let blocks = vec![
            cks.encrypt_bool(true).0,
            cks.encrypt_bool(true).0,
            cks.encrypt_one_block(block_max_value),
            cks.encrypt_one_block(block_max_value),
        ];
        let mut ct = RadixCiphertext::from(blocks);
        assert_eq!(
            ct.blocks[0].noise_degree(),
            CiphertextNoiseDegree::new(NoiseLevel::NOMINAL, Degree::new(1))
        );
        assert_eq!(
            ct.blocks[1].noise_degree(),
            CiphertextNoiseDegree::new(NoiseLevel::NOMINAL, Degree::new(1))
        );
        assert_eq!(
            ct.blocks[2].noise_degree(),
            CiphertextNoiseDegree::new(NoiseLevel::NOMINAL, Degree::new(block_max_value))
        );
        assert_eq!(
            ct.blocks[3].noise_degree(),
            CiphertextNoiseDegree::new(NoiseLevel::NOMINAL, Degree::new(block_max_value))
        );
        let ct_cloned = ct.clone();
        // Bounded by:
        // - the boolean blocks' maximal useful count;
        // - the degree limit of the full blocks;
        // - the maximum noise level.
        let num_ct_to_sum = block_max_value
            .min((block_total_mod - 1) / block_max_value)
            .min(sks.key.max_noise_level.get());
        let num_add = num_ct_to_sum - 1;
        for _ in 0..num_add {
            sks.unchecked_add_assign(&mut ct, &ct_cloned);
        }
        assert_eq!(
            ct.blocks[0].noise_degree(),
            CiphertextNoiseDegree::new(
                NoiseLevel::NOMINAL * num_ct_to_sum,
                Degree::new(num_ct_to_sum)
            )
        );
        assert_eq!(
            ct.blocks[1].noise_degree(),
            CiphertextNoiseDegree::new(
                NoiseLevel::NOMINAL * num_ct_to_sum,
                Degree::new(num_ct_to_sum)
            )
        );
        assert_eq!(
            ct.blocks[2].noise_degree(),
            CiphertextNoiseDegree::new(
                NoiseLevel::NOMINAL * num_ct_to_sum,
                Degree::new(block_max_value * num_ct_to_sum)
            )
        );
        assert_eq!(
            ct.blocks[3].noise_degree(),
            CiphertextNoiseDegree::new(
                NoiseLevel::NOMINAL * num_ct_to_sum,
                Degree::new(block_max_value * num_ct_to_sum)
            )
        );

        executor.execute(&mut ct);

        let clean_noise_degree =
            CiphertextNoiseDegree::new(NoiseLevel::NOMINAL, Degree::new(block_max_value));
        for block in &ct.blocks {
            assert_eq!(block.noise_degree(), clean_noise_degree);
        }
    }

    // Scenario 6: `max + max`, with every block except the first pushed to the maximum
    // allowed noise level.
    {
        let mut ct = RadixCiphertext::from(vec![
            cks.encrypt_one_block(block_msg_mod - 1),
            cks.encrypt_one_block(block_msg_mod - 1),
            cks.encrypt_one_block(block_msg_mod - 1),
            cks.encrypt_one_block(block_msg_mod - 1),
        ]);
        let ct_cloned = ct.clone();
        sks.unchecked_add_assign(&mut ct, &ct_cloned);
        for block in &mut ct.blocks[1..] {
            block.set_noise_level(
                NoiseLevel::NOMINAL * sks.key.max_noise_level.get(),
                sks.key.max_noise_level,
            );
        }

        executor.execute(&mut ct);

        let clean_degree = Degree::new(block_msg_mod - 1);
        for block in &ct.blocks {
            assert_eq!(block.noise_level(), NoiseLevel::NOMINAL);
            assert_eq!(block.degree, clean_degree);
        }
        let decrypted: u64 = cks.decrypt(&ct);
        // Modulus of this hand-built ciphertext, which has its own number of blocks.
        let ct_modulus = cks
            .parameters()
            .message_modulus()
            .0
            .pow(ct.blocks.len() as u32);
        let expected = ((ct_modulus - 1) * 2) % ct_modulus;
        assert_eq!(decrypted, expected);
    }

    // Scenario 7: a dirty first block followed by zero blocks whose degree is reset to 0. The
    // result must still be correct, with degrees at most `block_msg_mod - 1`.
    {
        let mut ct = RadixCiphertext::from(vec![
            cks.encrypt_one_block(block_msg_mod - 1),
            cks.encrypt_one_block(0),
            cks.encrypt_one_block(0),
            cks.encrypt_one_block(0),
        ]);
        let ct_cloned = ct.clone();
        sks.unchecked_add_assign(&mut ct, &ct_cloned);
        for block in &mut ct.blocks[1..] {
            block.degree = Degree::new(0);
        }
        assert_eq!(
            ct.blocks[0].noise_degree(),
            CiphertextNoiseDegree::new(
                NoiseLevel::NOMINAL * 2,
                Degree::new((block_msg_mod - 1) * 2)
            )
        );
        for block in &ct.blocks[1..] {
            assert_eq!(
                block.noise_degree(),
                CiphertextNoiseDegree::new(NoiseLevel::NOMINAL * 2, Degree::new(0))
            );
        }

        executor.execute(&mut ct);

        let clean_degree = Degree::new(block_msg_mod - 1);
        for block in &ct.blocks {
            assert_eq!(block.noise_level(), NoiseLevel::NOMINAL);
            assert!(block.degree <= clean_degree);
        }
        let decrypted: u64 = cks.decrypt(&ct);
        assert_eq!(decrypted, (block_msg_mod - 1) * 2);
    }

    // Scenario 8: an addition performed after casting both operands to twice as many blocks.
    // The sum fits, so it must not wrap.
    {
        let nb_blocks = 4;
        let num_bits_in_msg = sks.message_modulus().0.ilog2();
        let clear_a = (1u32 << num_bits_in_msg) - 1;
        let clear_b = ((1 << num_bits_in_msg) - 1) << (num_bits_in_msg * (nb_blocks - 1)) | 1u32;
        let a = cks.as_ref().encrypt_radix(clear_a, nb_blocks as usize);
        let b = cks.as_ref().encrypt_radix(clear_b, nb_blocks as usize);
        let mut a = sks.cast_to_unsigned(a, 2 * nb_blocks as usize);
        let b = sks.cast_to_unsigned(b, 2 * nb_blocks as usize);
        sks.unchecked_add_assign(&mut a, &b);

        executor.execute(&mut a);

        assert!(a.block_carries_are_empty());
        assert!(a
            .blocks
            .iter()
            .all(|b| b.noise_level() <= NoiseLevel::NOMINAL));
        let result: u32 = cks.as_ref().decrypt_radix(&a);
        let expected = clear_a + clear_b;
        assert_eq!(
            result,
            expected,
            "Invalid full propagation result for {clear_a} + {clear_b}, expected {expected}, got {result}"
        );
    }
}