use crate::integer::keycache::KEY_CACHE;
use crate::integer::server_key::radix::neg::NegatedDegreeIter;
use crate::integer::server_key::radix_parallel::tests_cases_unsigned::{FunctionExecutor, NB_CTXT};
use crate::integer::server_key::radix_parallel::tests_unsigned::{
nb_tests_for_params, nb_tests_smaller_for_params, overflowing_sub_under_modulus,
panic_if_any_block_info_exceeds_max_degree_or_noise, panic_if_any_block_is_not_clean,
panic_if_any_block_values_exceeds_its_degree, random_non_zero_value, unsigned_modulus,
unsigned_modulus_u128, CpuFunctionExecutor, ExpectedDegrees, ExpectedNoiseLevels,
};
use crate::integer::tests::create_parameterized_test;
use crate::integer::{BooleanBlock, IntegerKeyKind, RadixCiphertext, RadixClientKey, ServerKey};
#[cfg(tarpaulin)]
use crate::shortint::parameters::coverage_parameters::*;
use crate::shortint::parameters::test_params::*;
use crate::shortint::parameters::*;
use rand::Rng;
use std::sync::Arc;
use super::MAX_NB_CTXT;
// Register every test entry point defined below through `create_parameterized_test!`.
// NOTE(review): invocations that only pass a function name presumably rely on the
// macro's default parameter sets — confirm against the macro definition.
create_parameterized_test!(integer_unchecked_sub);
create_parameterized_test!(integer_smart_sub);
create_parameterized_test!(integer_default_sub);
create_parameterized_test!(integer_extensive_trivial_default_sub);
create_parameterized_test!(integer_default_overflowing_sub);
create_parameterized_test!(integer_extensive_trivial_default_overflowing_sub);
// This test is given explicit parameter lists, split into a `coverage` group and a
// `no_coverage` group. Every listed parameter set is named with a message/carry size
// of at least 2/2, which looks intended to match the `_at_least_4_bits` suffix of the
// tested function — TODO confirm.
create_parameterized_test!(integer_advanced_sub_assign_with_borrow_at_least_4_bits {
coverage => {
COVERAGE_PARAM_MESSAGE_2_CARRY_2_KS_PBS,
COVERAGE_PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_2_KS_PBS
},
no_coverage => {
PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128,
TEST_PARAM_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M128,
TEST_PARAM_MESSAGE_4_CARRY_4_KS_PBS_GAUSSIAN_2M64,
TEST_PARAM_MULTI_BIT_GROUP_2_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64,
TEST_PARAM_MULTI_BIT_GROUP_2_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M64,
TEST_PARAM_MULTI_BIT_GROUP_3_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64,
TEST_PARAM_MULTI_BIT_GROUP_3_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M64
}
});
create_parameterized_test!(integer_advanced_sub_assign_with_borrow_sequential);
/// Runs [`unchecked_sub_test`] with a CPU executor wrapping `ServerKey::unchecked_sub`.
fn integer_unchecked_sub<P>(param: P)
where
    P: Into<TestParameters>,
{
    unchecked_sub_test(param, CpuFunctionExecutor::new(&ServerKey::unchecked_sub));
}
/// Runs [`smart_sub_test`] with a CPU executor wrapping `ServerKey::smart_sub_parallelized`.
fn integer_smart_sub<P>(param: P)
where
    P: Into<TestParameters>,
{
    smart_sub_test(
        param,
        CpuFunctionExecutor::new(&ServerKey::smart_sub_parallelized),
    );
}
/// Runs [`default_sub_test`] with a CPU executor wrapping `ServerKey::sub_parallelized`.
fn integer_default_sub<P>(param: P)
where
    P: Into<TestParameters>,
{
    default_sub_test(param, CpuFunctionExecutor::new(&ServerKey::sub_parallelized));
}
/// Runs [`extensive_trivial_default_sub_test`] with a CPU executor wrapping
/// `ServerKey::sub_parallelized`.
fn integer_extensive_trivial_default_sub<P>(param: P)
where
    P: Into<TestParameters>,
{
    extensive_trivial_default_sub_test(
        param,
        CpuFunctionExecutor::new(&ServerKey::sub_parallelized),
    );
}
/// Runs [`default_overflowing_sub_test`] with a CPU executor wrapping
/// `ServerKey::unsigned_overflowing_sub_parallelized`.
fn integer_default_overflowing_sub<P>(param: P)
where
    P: Into<TestParameters>,
{
    default_overflowing_sub_test(
        param,
        CpuFunctionExecutor::new(&ServerKey::unsigned_overflowing_sub_parallelized),
    );
}
/// Runs [`extensive_trivial_default_overflowing_sub_test`] with a CPU executor wrapping
/// `ServerKey::unsigned_overflowing_sub_parallelized`.
fn integer_extensive_trivial_default_overflowing_sub<P>(param: P)
where
    P: Into<TestParameters>,
{
    extensive_trivial_default_overflowing_sub_test(
        param,
        CpuFunctionExecutor::new(&ServerKey::unsigned_overflowing_sub_parallelized),
    );
}
/// Runs [`default_overflowing_sub_test`] against a closure that calls
/// `ServerKey::advanced_sub_assign_with_borrow_parallelized_at_least_4_bits` directly.
///
/// The closure works on a copy of `lhs`, and both operands have their carries
/// propagated first whenever `block_carries_are_empty()` reports `false`. The call is
/// made with `None` and `true` as trailing arguments.
/// NOTE(review): these presumably mean "no input borrow" and "compute the overflow
/// flag" — confirm against the method's signature.
fn integer_advanced_sub_assign_with_borrow_at_least_4_bits<P>(param: P)
where
    P: Into<TestParameters>,
{
    let func = |sks: &ServerKey, lhs: &RadixCiphertext, rhs: &RadixCiphertext| {
        // The operation is in-place, so it is applied to a copy of the left operand.
        let mut difference = lhs.clone();
        if !difference.block_carries_are_empty() {
            sks.full_propagate_parallelized(&mut difference);
        }

        // Only clone and propagate the right operand when it actually has carries.
        let propagated_rhs = (!rhs.block_carries_are_empty()).then(|| {
            let mut owned = rhs.clone();
            sks.full_propagate_parallelized(&mut owned);
            owned
        });
        let clean_rhs = propagated_rhs.as_ref().unwrap_or(rhs);

        let overflowed = sks
            .advanced_sub_assign_with_borrow_parallelized_at_least_4_bits(
                &mut difference,
                clean_rhs,
                None,
                true,
            )
            .expect("Overflow flag was requested");

        (difference, overflowed)
    };

    default_overflowing_sub_test(param, CpuFunctionExecutor::new(&func));
}
/// Runs [`default_overflowing_sub_test`] against a closure that calls
/// `ServerKey::advanced_sub_assign_with_borrow_sequential` directly.
///
/// The closure works on a copy of `lhs`, and both operands have their carries
/// propagated first whenever `block_carries_are_empty()` reports `false`. The call is
/// made with `None` and `true` as trailing arguments.
/// NOTE(review): these presumably mean "no input borrow" and "compute the overflow
/// flag" — confirm against the method's signature.
fn integer_advanced_sub_assign_with_borrow_sequential<P>(param: P)
where
    P: Into<TestParameters>,
{
    let func = |sks: &ServerKey, lhs: &RadixCiphertext, rhs: &RadixCiphertext| {
        // The operation is in-place, so it is applied to a copy of the left operand.
        let mut difference = lhs.clone();
        if !difference.block_carries_are_empty() {
            sks.full_propagate_parallelized(&mut difference);
        }

        // Only clone and propagate the right operand when it actually has carries.
        let propagated_rhs = (!rhs.block_carries_are_empty()).then(|| {
            let mut owned = rhs.clone();
            sks.full_propagate_parallelized(&mut owned);
            owned
        });
        let clean_rhs = propagated_rhs.as_ref().unwrap_or(rhs);

        let overflowed = sks
            .advanced_sub_assign_with_borrow_sequential(&mut difference, clean_rhs, None, true)
            .expect("Overflow flag was requested");

        (difference, overflowed)
    };

    default_overflowing_sub_test(param, CpuFunctionExecutor::new(&func));
}
impl ExpectedDegrees {
    /// Sets the expected per-block degrees for the result of an unchecked `lhs - rhs`.
    ///
    /// Each expected degree is the degree of the `lhs` block plus the matching item
    /// yielded by `NegatedDegreeIter` over the `(degree, message_modulus)` pairs of the
    /// `rhs` blocks. Returns `self` so a check can be chained.
    fn after_unchecked_sub(&mut self, lhs: &RadixCiphertext, rhs: &RadixCiphertext) -> &Self {
        let rhs_degree_info = rhs
            .blocks
            .iter()
            .map(|rhs_block| (rhs_block.degree, rhs_block.message_modulus));

        let expected = lhs
            .blocks
            .iter()
            .zip(NegatedDegreeIter::new(rhs_degree_info))
            .map(|(lhs_block, negated_degree)| lhs_block.degree + negated_degree);

        self.set_with(expected);
        self
    }
}
impl ExpectedNoiseLevels {
    /// Sets the expected per-block noise levels for the result of an unchecked `lhs - rhs`.
    ///
    /// Each expected level is the sum of the noise levels of the paired `lhs` and `rhs`
    /// blocks. Returns `self` so a check can be chained.
    fn after_unchecked_sub(&mut self, lhs: &RadixCiphertext, rhs: &RadixCiphertext) -> &Self {
        let block_pairs = lhs.blocks.iter().zip(rhs.blocks.iter());
        let expected =
            block_pairs.map(|(lhs_block, rhs_block)| lhs_block.noise_level() + rhs_block.noise_level());

        self.set_with(expected);
        self
    }
}
/// Checks an unchecked subtraction executor on freshly encrypted random operands.
///
/// For `nb_tests_for_params(param)` random pairs, the executor's output is checked for:
/// * per-block noise levels and degrees equal to the values computed by
///   `ExpectedNoiseLevels::after_unchecked_sub` / `ExpectedDegrees::after_unchecked_sub`,
/// * the `panic_if_any_block_values_exceeds_its_degree` and
///   `panic_if_any_block_info_exceeds_max_degree_or_noise` helpers not panicking,
/// * a decrypted value equal to `clear_0.wrapping_sub(clear_1) % modulus`.
///
/// # Panics
/// Panics on the first failed check.
pub(crate) fn unchecked_sub_test<P, T>(param: P, mut executor: T)
where
    P: Into<TestParameters>,
    T: for<'a> FunctionExecutor<(&'a RadixCiphertext, &'a RadixCiphertext), RadixCiphertext>,
{
    // Block count used for the modulus and to size the expected degree/noise trackers.
    let nb_blocks = crate::integer::server_key::radix_parallel::tests_unsigned::NB_CTXT;

    let param = param.into();
    let nb_tests = nb_tests_for_params(param);

    let (raw_cks, raw_sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
    let sks = Arc::new(raw_sks);
    let cks = RadixClientKey::from((raw_cks, NB_CTXT));

    let mut rng = rand::thread_rng();
    let modulus = unsigned_modulus(cks.parameters().message_modulus(), nb_blocks as u32);

    // Read the limits before the server key is handed over to the executor.
    let max_noise_level = sks.key.max_noise_level;
    let max_degree = sks.key.max_degree;
    executor.setup(&cks, sks);

    let mut expected_noise_levels = ExpectedNoiseLevels::new(NoiseLevel::ZERO, nb_blocks);
    let mut expected_degrees = ExpectedDegrees::new(Degree::new(0), nb_blocks);

    for _ in 0..nb_tests {
        let clear_0 = rng.gen::<u64>() % modulus;
        let clear_1 = rng.gen::<u64>() % modulus;

        let lhs = cks.encrypt(clear_0);
        let rhs = cks.encrypt(clear_1);

        let encrypted_result = executor.execute((&lhs, &rhs));

        // Metadata checks: noise first, then degrees, then the global limits.
        expected_noise_levels
            .after_unchecked_sub(&lhs, &rhs)
            .panic_if_any_is_not_equal(&encrypted_result);
        expected_degrees
            .after_unchecked_sub(&lhs, &rhs)
            .panic_if_any_is_not_equal(&encrypted_result);
        panic_if_any_block_values_exceeds_its_degree(&encrypted_result, &cks);
        panic_if_any_block_info_exceeds_max_degree_or_noise(
            &encrypted_result,
            max_degree,
            max_noise_level,
        );

        // Value check against the clear-text computation.
        let expected_result = clear_0.wrapping_sub(clear_1) % modulus;
        let decrypted_result: u64 = cks.decrypt(&encrypted_result);
        assert_eq!(
            decrypted_result, expected_result,
            "Invalid sub result, expected {clear_0} - {clear_1} \
             to be {expected_result}, but got {decrypted_result}."
        );
    }
}
/// Checks a "smart" subtraction executor by chaining subtractions on one accumulator.
///
/// For each random pair `(clear1, clear2)`, an accumulator starting as the encryption
/// of `clear1` is repeatedly replaced by the executor's result for
/// `(accumulator, enc(clear2))`. After every step the test:
/// * calls `panic_if_any_block_info_exceeds_max_degree_or_noise` with the server key's
///   `max_degree` / `max_noise_level`,
/// * calls `panic_if_any_block_values_exceeds_its_degree`,
/// * asserts the decrypted accumulator equals the clear running difference
///   (`wrapping_sub` reduced modulo the radix modulus).
///
/// Both the outer and the inner loop run `nb_tests_smaller_for_params(param)` times.
///
/// # Panics
/// Panics on the first failed check; the message reports the operands of the failing
/// step so a randomized failure can be reproduced.
pub(crate) fn smart_sub_test<P, T>(param: P, mut executor: T)
where
    P: Into<TestParameters>,
    T: for<'a> FunctionExecutor<
        (&'a mut RadixCiphertext, &'a mut RadixCiphertext),
        RadixCiphertext,
    >,
{
    let param = param.into();
    let nb_tests_smaller = nb_tests_smaller_for_params(param);
    let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
    let sks = Arc::new(sks);
    let cks = RadixClientKey::from((cks, NB_CTXT));
    let mut rng = rand::thread_rng();
    let modulus = unsigned_modulus(cks.parameters().message_modulus(), NB_CTXT as u32);
    // Read the limits before the server key is handed over to the executor.
    let max_noise_level = sks.key.max_noise_level;
    let max_degree = sks.key.max_degree;
    executor.setup(&cks, sks);

    for _ in 0..nb_tests_smaller {
        let clear1 = rng.gen::<u64>() % modulus;
        let clear2 = rng.gen::<u64>() % modulus;

        // The accumulator starts as the encryption of `clear1`; no separate copy of the
        // initial ciphertext is needed since it is never read again.
        let mut res: RadixCiphertext = cks.encrypt(clear1);
        let mut ctxt_2 = cks.encrypt(clear2);
        let mut clear = clear1;

        // Subtract the same operand repeatedly so the accumulator is not a fresh
        // ciphertext on later iterations.
        for _ in 0..nb_tests_smaller {
            res = executor.execute((&mut res, &mut ctxt_2));
            panic_if_any_block_info_exceeds_max_degree_or_noise(&res, max_degree, max_noise_level);
            panic_if_any_block_values_exceeds_its_degree(&res, &cks);

            let expected = clear.wrapping_sub(clear2) % modulus;
            let dec_res: u64 = cks.decrypt(&res);
            assert_eq!(
                expected, dec_res,
                "Invalid smart sub result, expected ({clear} - {clear2}) % {modulus} \
                 to be {expected}, but got {dec_res} (chain started from {clear1})"
            );
            clear = expected;
        }
    }
}
/// Checks a default subtraction executor for every block count in `1..MAX_NB_CTXT`.
///
/// Deterministic PBS execution is enabled on the server key before setup, and every
/// subtraction is executed twice on the same inputs so the two results can be compared
/// for equality. For each block count and each random pair `(clear1, clear2)`, an
/// accumulator starting as the encryption of `clear1` is repeatedly replaced by the
/// executor's result for `(accumulator, enc(clear2))`. After every step the test
/// asserts that:
/// * `block_carries_are_empty()` holds on the result,
/// * both executions returned equal ciphertexts,
/// * `panic_if_any_block_is_not_clean` does not panic,
/// * the decrypted result equals the clear running difference modulo
///   `message_modulus ^ num_blocks`.
///
/// # Panics
/// Panics on the first failed check; messages report the operands, modulus and block
/// count of the failing step so a randomized failure can be reproduced.
pub(crate) fn default_sub_test<P, T>(param: P, mut executor: T)
where
    P: Into<TestParameters>,
    T: for<'a> FunctionExecutor<(&'a RadixCiphertext, &'a RadixCiphertext), RadixCiphertext>,
{
    let param = param.into();
    let nb_tests_smaller = nb_tests_smaller_for_params(param);
    let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
    let cks = RadixClientKey::from((cks, NB_CTXT));
    sks.set_deterministic_pbs_execution(true);
    let sks = Arc::new(sks);
    let mut rng = rand::thread_rng();
    executor.setup(&cks, sks);

    for num_blocks in 1..MAX_NB_CTXT {
        let modulus = cks.parameters().message_modulus().0.pow(num_blocks as u32);

        for _ in 0..nb_tests_smaller {
            let clear1 = rng.gen::<u64>() % modulus;
            let clear2 = rng.gen::<u64>() % modulus;

            // The accumulator starts as the encryption of `clear1`; no separate copy of
            // the initial ciphertext is needed since it is never read again.
            let mut res: RadixCiphertext = cks.as_ref().encrypt_radix(clear1, num_blocks);
            let ctxt_2 = cks.as_ref().encrypt_radix(clear2, num_blocks);
            let mut clear = clear1;

            for _ in 0..nb_tests_smaller {
                // Execute twice on identical inputs to check the result is reproducible.
                let tmp = executor.execute((&res, &ctxt_2));
                res = executor.execute((&res, &ctxt_2));
                assert!(
                    res.block_carries_are_empty(),
                    "Sub result has non-empty carries for ({clear} - {clear2}) % {modulus} \
                     ({num_blocks} blocks)"
                );
                assert_eq!(
                    res, tmp,
                    "Failed determinism check for ({clear} - {clear2}) % {modulus} \
                     ({num_blocks} blocks)"
                );
                panic_if_any_block_is_not_clean(&res, &cks);

                let expected = clear.wrapping_sub(clear2) % modulus;
                let dec: u64 = cks.decrypt(&res);
                assert_eq!(
                    expected, dec,
                    "Invalid result for sub, for ({clear} - {clear2}) % {modulus} \
                     ({num_blocks} blocks) expected {expected}, got {dec}"
                );
                clear = expected;
            }
        }
    }
}
/// Checks a subtraction executor on trivially encrypted operands over many block counts.
///
/// Bit sizes start at 1 and advance by the number of bits per block up to 64, so
/// `bit_size.div_ceil(block_num_bits)` yields each block count exactly once. For each
/// block count, 50 random pairs are turned into trivial radix ciphertexts with
/// `create_trivial_radix`, subtracted by the executor, and the decrypted result is
/// compared to `clear_0.wrapping_sub(clear_1) % modulus`.
///
/// # Panics
/// Panics on the first mismatch between the decrypted and expected values.
pub(crate) fn extensive_trivial_default_sub_test<P, T>(param: P, mut sub_executor: T)
where
    P: Into<TestParameters>,
    T: for<'a> FunctionExecutor<(&'a RadixCiphertext, &'a RadixCiphertext), RadixCiphertext>,
{
    // Number of random operand pairs tried for every block count.
    const TRIALS_PER_SIZE: usize = 50;

    let param = param.into();
    let (raw_cks, mut raw_sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
    let cks = RadixClientKey::from((
        raw_cks,
        crate::integer::server_key::radix_parallel::tests_unsigned::NB_CTXT,
    ));
    raw_sks.set_deterministic_pbs_execution(true);
    let sks = Arc::new(raw_sks);

    let mut rng = rand::thread_rng();
    sub_executor.setup(&cks, sks.clone());

    let message_modulus = cks.parameters().message_modulus();
    let block_num_bits = message_modulus.0.ilog2();

    for bit_size in (1..=64u32).step_by(block_num_bits as usize) {
        let num_blocks = bit_size.div_ceil(block_num_bits);
        let modulus = unsigned_modulus_u128(cks.parameters().message_modulus(), num_blocks);

        for _ in 0..TRIALS_PER_SIZE {
            let clear_0 = rng.gen::<u128>() % modulus;
            let clear_1 = rng.gen::<u128>() % modulus;

            let trivial_lhs = sks.create_trivial_radix(clear_0, num_blocks as usize);
            let trivial_rhs = sks.create_trivial_radix(clear_1, num_blocks as usize);

            let ct_res = sub_executor.execute((&trivial_lhs, &trivial_rhs));
            let dec_res: u128 = cks.decrypt(&ct_res);

            let expected_clear = clear_0.wrapping_sub(clear_1) % modulus;
            assert_eq!(
                expected_clear, dec_res,
                "Invalid result for {clear_0} - {clear_1}, expected: {expected_clear}, got: {dec_res}\n\
                 num_blocks={num_blocks}, modulus={modulus}"
            );
        }
    }
}
/// Checks an overflowing subtraction executor for every block count in `1..MAX_NB_CTXT`.
///
/// Deterministic PBS execution is enabled on the server key before setup. For each
/// block count, three kinds of inputs are exercised:
/// 1. Freshly encrypted random operands: the executor is run twice and both the result
///    and the overflow flag must be equal across runs; the result must have empty
///    carries and match `overflowing_sub_under_modulus`; the flag block must have
///    degree 1 and noise level `NoiseLevel::NOMINAL`.
/// 2. The same operands after `unchecked_scalar_add` of a non-zero random scalar, with
///    the same value and flag checks (the operands are sanity-decrypted first).
/// 3. Four pairs of trivial ciphertexts whose clear values are drawn below
///    `modulus / message_modulus`; the test asserts their last block has degree 0, and
///    here the flag's noise level only has to be `<= NoiseLevel::NOMINAL`.
///
/// # Panics
/// Panics on the first failed check.
pub(crate) fn default_overflowing_sub_test<P, T>(param: P, mut executor: T)
where
    P: Into<TestParameters>,
    T: for<'a> FunctionExecutor<
        (&'a RadixCiphertext, &'a RadixCiphertext),
        (RadixCiphertext, BooleanBlock),
    >,
{
    let param = param.into();
    let nb_tests_smaller = nb_tests_smaller_for_params(param);

    let (raw_cks, mut raw_sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
    let radix_cks = RadixClientKey::from((raw_cks, NB_CTXT));
    raw_sks.set_deterministic_pbs_execution(true);
    let sks = Arc::new(raw_sks);

    let mut rng = rand::thread_rng();
    executor.setup(&radix_cks, sks.clone());

    // From here on the plain client key is used, with explicit block counts.
    let cks: crate::integer::ClientKey = radix_cks.into();

    for num_blocks in 1..MAX_NB_CTXT {
        let modulus = cks.parameters().message_modulus().0.pow(num_blocks as u32);

        for _ in 0..nb_tests_smaller {
            let clear_0 = rng.gen::<u64>() % modulus;
            let clear_1 = rng.gen::<u64>() % modulus;

            let ctxt_0 = cks.encrypt_radix(clear_0, num_blocks);
            let ctxt_1 = cks.encrypt_radix(clear_1, num_blocks);

            // Two executions on identical inputs must produce identical outputs.
            let (first_res, first_overflow) = executor.execute((&ctxt_0, &ctxt_1));
            let (second_res, second_overflow) = executor.execute((&ctxt_0, &ctxt_1));
            assert!(first_res.block_carries_are_empty());
            assert_eq!(first_res, second_res, "Failed determinism check, \n\n\n msg0: {clear_0}, msg1: {clear_1}, \n\n\nctxt0: {ctxt_0:?}, \n\n\nctxt1: {ctxt_1:?}\n\n\n");
            assert_eq!(second_overflow, first_overflow, "Failed determinism check, \n\n\n msg0: {clear_0}, msg1: {clear_1}, \n\n\nctxt0: {ctxt_0:?}, \n\n\nctxt1: {ctxt_1:?}\n\n\n");

            let (expected_result, expected_overflowed) =
                overflowing_sub_under_modulus(clear_0, clear_1, modulus);
            let decrypted_result: u64 = cks.decrypt_radix(&first_res);
            let decrypted_overflowed = cks.decrypt_bool(&first_overflow);
            assert_eq!(
                decrypted_result, expected_result,
                "Invalid result for sub, for ({clear_0} - {clear_1}) % {modulus} ({num_blocks} blocks) \
                 expected {expected_result}, got {decrypted_result}"
            );
            assert_eq!(
                decrypted_overflowed,
                expected_overflowed,
                "Invalid overflow flag result for overflowing_sub for ({clear_0} - {clear_1}) % {modulus} ({num_blocks} blocks) \
                 expected overflow flag {expected_overflowed}, got {decrypted_overflowed}"
            );
            assert_eq!(first_overflow.0.degree.get(), 1);
            assert_eq!(first_overflow.0.noise_level(), NoiseLevel::NOMINAL);

            for _ in 0..nb_tests_smaller {
                // Add a non-zero scalar to each clean operand; every iteration starts
                // again from the clean `ctxt_0` / `ctxt_1`.
                let scalar_lhs = random_non_zero_value(&mut rng, modulus);
                let scalar_rhs = random_non_zero_value(&mut rng, modulus);

                let dirty_lhs = sks.unchecked_scalar_add(&ctxt_0, scalar_lhs);
                let dirty_rhs = sks.unchecked_scalar_add(&ctxt_1, scalar_rhs);

                let clear_lhs = clear_0.wrapping_add(scalar_lhs) % modulus;
                let clear_rhs = clear_1.wrapping_add(scalar_rhs) % modulus;

                let sanity_lhs: u64 = cks.decrypt_radix(&dirty_lhs);
                assert_eq!(sanity_lhs, clear_lhs, "Failed sanity decryption check");
                let sanity_rhs: u64 = cks.decrypt_radix(&dirty_rhs);
                assert_eq!(sanity_rhs, clear_rhs, "Failed sanity decryption check");

                let (dirty_res, dirty_overflow) = executor.execute((&dirty_lhs, &dirty_rhs));
                assert!(dirty_res.block_carries_are_empty());

                let (expected_result, expected_overflowed) =
                    overflowing_sub_under_modulus(clear_lhs, clear_rhs, modulus);
                let decrypted_result: u64 = cks.decrypt_radix(&dirty_res);
                let decrypted_overflowed = cks.decrypt_bool(&dirty_overflow);
                assert_eq!(
                    decrypted_result, expected_result,
                    "Invalid result for sub, for ({clear_lhs} - {clear_rhs}) % {modulus} ({num_blocks} blocks) \
                     expected {expected_result}, got {decrypted_result}"
                );
                assert_eq!(
                    decrypted_overflowed,
                    expected_overflowed,
                    "Invalid overflow flag result for overflowing_sub, for ({clear_lhs} - {clear_rhs}) % {modulus} ({num_blocks} blocks) \
                     expected overflow flag {expected_overflowed}, got {decrypted_overflowed}"
                );
                assert_eq!(dirty_overflow.0.degree.get(), 1);
                assert_eq!(dirty_overflow.0.noise_level(), NoiseLevel::NOMINAL);
            }
        }

        // Trivial inputs, drawn from a range reduced by one block's worth of values.
        for _ in 0..4 {
            let reduced_modulus = modulus / sks.key.message_modulus.0;
            let clear_0 = rng.gen::<u64>() % reduced_modulus;
            let clear_1 = rng.gen::<u64>() % reduced_modulus;

            let trivial_lhs: RadixCiphertext = sks.create_trivial_radix(clear_0, num_blocks);
            let trivial_rhs: RadixCiphertext = sks.create_trivial_radix(clear_1, num_blocks);

            // The most significant block of each operand is expected to have degree 0.
            assert_eq!(trivial_lhs.blocks[num_blocks - 1].degree.get(), 0);
            assert_eq!(trivial_rhs.blocks[num_blocks - 1].degree.get(), 0);

            let (trivial_res, trivial_overflow) = executor.execute((&trivial_lhs, &trivial_rhs));

            let (expected_result, expected_overflowed) =
                overflowing_sub_under_modulus(clear_0, clear_1, modulus);
            let decrypted_result: u64 = cks.decrypt_radix(&trivial_res);
            let decrypted_overflowed = cks.decrypt_bool(&trivial_overflow);
            assert_eq!(
                decrypted_result, expected_result,
                "Invalid result for sub, for ({clear_0} - {clear_1}) % {modulus} \
                 expected {expected_result}, got {decrypted_result}"
            );
            assert_eq!(
                decrypted_overflowed,
                expected_overflowed,
                "Invalid overflow flag result for overflowing_sub, for ({clear_0} - {clear_1}) % {modulus} \
                 expected overflow flag {expected_overflowed}, got {decrypted_overflowed}"
            );
            assert_eq!(trivial_overflow.0.degree.get(), 1);
            assert!(trivial_overflow.0.noise_level() <= NoiseLevel::NOMINAL);
        }
    }
}
/// Checks an overflowing subtraction executor on trivially encrypted operands.
///
/// For every bit size from 1 to 64, the block count is `bit_size.div_ceil(block_num_bits)`
/// and 50 random pairs are turned into trivial radix ciphertexts with
/// `create_trivial_radix`. The decrypted result and the decrypted overflow flag are
/// compared to the pair returned by `overflowing_sub_under_modulus`.
///
/// # Panics
/// Panics on the first mismatch of either the value or the overflow flag.
pub(crate) fn extensive_trivial_default_overflowing_sub_test<P, T>(
    param: P,
    mut overflowing_sub_executor: T,
) where
    P: Into<TestParameters>,
    T: for<'a> FunctionExecutor<
        (&'a RadixCiphertext, &'a RadixCiphertext),
        (RadixCiphertext, BooleanBlock),
    >,
{
    // Number of random operand pairs tried for every bit size.
    const TRIALS_PER_SIZE: usize = 50;

    let param = param.into();
    let (raw_cks, mut raw_sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
    let cks = RadixClientKey::from((
        raw_cks,
        crate::integer::server_key::radix_parallel::tests_unsigned::NB_CTXT,
    ));
    raw_sks.set_deterministic_pbs_execution(true);
    let sks = Arc::new(raw_sks);

    let mut rng = rand::thread_rng();
    overflowing_sub_executor.setup(&cks, sks.clone());

    let message_modulus = cks.parameters().message_modulus();
    let block_num_bits = message_modulus.0.ilog2();

    for bit_size in 1..=64u32 {
        let num_blocks = bit_size.div_ceil(block_num_bits);
        let modulus = unsigned_modulus_u128(cks.parameters().message_modulus(), num_blocks);

        for _ in 0..TRIALS_PER_SIZE {
            let clear_0 = rng.gen::<u128>() % modulus;
            let clear_1 = rng.gen::<u128>() % modulus;

            let trivial_lhs = sks.create_trivial_radix(clear_0, num_blocks as usize);
            let trivial_rhs = sks.create_trivial_radix(clear_1, num_blocks as usize);

            let (ct_res, overflow_block) =
                overflowing_sub_executor.execute((&trivial_lhs, &trivial_rhs));
            let dec_res: u128 = cks.decrypt(&ct_res);
            let dec_overflow = cks.decrypt_bool(&overflow_block);

            let (expected_clear, expected_overflow) =
                overflowing_sub_under_modulus(clear_0, clear_1, modulus);
            assert_eq!(
                expected_clear, dec_res,
                "Invalid result for {clear_0} - {clear_1}, expected: {expected_clear}, got: {dec_res}\n\
                 num_blocks={num_blocks}, modulus={modulus}"
            );
            assert_eq!(
                expected_overflow, dec_overflow,
                "Invalid overflow result for {clear_0} - {clear_1}, expected: {expected_overflow}, got: {dec_overflow}\n\
                 num_blocks={num_blocks}, modulus={modulus}"
            );
        }
    }
}