use concrete_npe as npe;
use crate::crypto::encoding::{Cleartext, CleartextList, Plaintext, PlaintextList};
use crate::crypto::lwe::{LweCiphertext, LweKeyswitchKey, LweList};
use crate::crypto::secret::LweSecretKey;
use crate::crypto::{CiphertextCount, CleartextCount, LweDimension, PlaintextCount, UnsignedTorus};
use crate::math::decomposition::{DecompositionBaseLog, DecompositionLevelCount};
use crate::math::dispersion::{DispersionParameter, LogStandardDev, Variance};
use crate::math::random;
use crate::math::random::{
fill_with_random_uniform, random_uniform_n_msb_tensor, RandomGenerable, UniformMsb,
};
use crate::math::tensor::{AsMutTensor, AsRefTensor, Tensor};
use crate::numeric::{CastFrom, Numeric, SignedInteger};
use crate::test_tools::{
assert_delta_std_dev, assert_noise_distribution, random_ciphertext_count, random_lwe_dimension,
random_utorus_between,
};
/// Checks that key-switching LWE ciphertexts from a dimension-1024 key to a dimension-600 key
/// preserves the encrypted plaintexts, with a decryption error consistent with the variance
/// predicted by `npe::LWE::key_switch`.
fn test_keyswitch<T: UnsignedTorus + RandomGenerable<UniformMsb> + npe::LWE>() {
    let n_bit_msg = 8;
    let nb_ct = random_ciphertext_count(100);
    let base_log = DecompositionBaseLog(3);
    let level_count = DecompositionLevelCount(8);
    // Random plaintexts from `random_uniform_n_msb_tensor` (presumably only the `n_bit_msg`
    // most significant bits are random — confirm in `math::random`).
    let messages = PlaintextList::from_tensor(random_uniform_n_msb_tensor(nb_ct.0, n_bit_msg));
    let std_input = LogStandardDev::from_log_standard_dev(-10.);
    let std_ksk = LogStandardDev::from_log_standard_dev(-25.);

    let dimension_after = LweDimension(600);
    let sk_after = LweSecretKey::generate(dimension_after);
    let dimension_before = LweDimension(1024);
    let sk_before = LweSecretKey::generate(dimension_before);

    let mut ciphertexts_before = LweList::allocate(T::ZERO, dimension_before.to_lwe_size(), nb_ct);
    let mut ciphertexts_after = LweList::allocate(T::ZERO, dimension_after.to_lwe_size(), nb_ct);
    let mut ksk = LweKeyswitchKey::allocate(
        T::ZERO,
        level_count,
        base_log,
        dimension_before,
        dimension_after,
    );
    // `LogStandardDev` is `Copy` (it is passed by value and then reused elsewhere in this
    // file), so no `.clone()` is needed here.
    ksk.fill_with_keyswitch_key(&sk_before, &sk_after, std_ksk);
    sk_before.encrypt_lwe_list(&mut ciphertexts_before, &messages, std_input);
    ksk.keyswitch_list(&mut ciphertexts_after, &ciphertexts_before);

    let mut dec_messages = PlaintextList::allocate(T::ZERO, PlaintextCount(nb_ct.0));
    sk_after.decrypt_lwe_list(&mut dec_messages, &ciphertexts_after);

    let output_variance = <T as npe::LWE>::key_switch(
        dimension_before.0,
        level_count.0,
        base_log.0,
        std_ksk.get_variance(),
        std_input.get_variance(),
    );
    // Fewer than 7 samples: use the per-sample delta check instead of the distribution
    // test (presumably too few samples for the latter to be meaningful).
    if nb_ct.0 < 7 {
        assert_delta_std_dev(
            &messages,
            &dec_messages,
            Variance::from_variance(output_variance),
        );
    } else {
        assert_noise_distribution(
            &messages,
            &dec_messages,
            Variance::from_variance(output_variance),
        );
    }
}
/// Runs the generic keyswitch test over a 32-bit torus.
#[test]
fn test_keyswitch_u32() {
    test_keyswitch::<u32>()
}
/// Runs the generic keyswitch test over a 64-bit torus.
#[test]
fn test_keyswitch_u64() {
    test_keyswitch::<u64>()
}
/// Encrypts a random number of uniformly random plaintexts under a freshly generated key of
/// random dimension, decrypts them, and checks the decryption error against the encryption
/// noise (log2 standard deviation -25).
fn test_encrypt_decrypt<T: UnsignedTorus>() {
    let ciphertext_count = random_ciphertext_count(100000);
    let lwe_dim = random_lwe_dimension(1000);
    let noise = LogStandardDev::from_log_standard_dev(-25.);
    let secret_key = LweSecretKey::generate(lwe_dim);

    let plaintexts =
        PlaintextList::from_tensor(random::random_uniform_tensor(ciphertext_count.0));
    let mut encrypted = LweList::allocate(T::ZERO, lwe_dim.to_lwe_size(), ciphertext_count);
    secret_key.encrypt_lwe_list(&mut encrypted, &plaintexts, noise);

    let mut decrypted = PlaintextList::allocate(T::ZERO, PlaintextCount(ciphertext_count.0));
    secret_key.decrypt_lwe_list(&mut decrypted, &encrypted);

    // Small samples get a per-element delta check; larger ones a distribution test.
    match ciphertext_count.0 {
        0..=6 => assert_delta_std_dev(&plaintexts, &decrypted, noise),
        _ => assert_noise_distribution(&plaintexts, &decrypted, noise),
    }
}
/// Runs the generic encrypt/decrypt round-trip test over a 32-bit torus.
#[test]
fn test_encrypt_decrypt_u32() {
    test_encrypt_decrypt::<u32>();
}
/// Runs the generic encrypt/decrypt round-trip test over a 64-bit torus.
#[test]
fn test_encrypt_decrypt_u64() {
    test_encrypt_decrypt::<u64>();
}
/// Checks the noise of `fill_with_multisum_with_bias` against the variance predicted by
/// `npe::LWE::multisum_uncorrelated`.
///
/// One set of random signed weights (drawn from `0..512`, shifted by -256) and one random
/// bias (drawn from `0..1024`) are shared by `n_tests` runs. Each run uses a fresh key and
/// fresh random messages, which are both encrypted and trivially encrypted (the "witness").
/// The decrypted multisum of the real ciphertexts is compared with the body of the
/// witness multisum, giving one noise sample per run.
fn test_multisum_npe<T>()
where
    T: UnsignedTorus + RandomGenerable<UniformMsb> + npe::LWE + CastFrom<usize>,
{
    let n_tests = 10;
    // Exactly one sample per run. Sizing these by `n_tests` (instead of a fixed, larger
    // length) keeps never-written zero entries out of the statistical checks below.
    let mut new_msg = Tensor::allocate(T::ZERO, n_tests);
    let mut msg = Tensor::allocate(T::ZERO, n_tests);
    let nb_ct = random_ciphertext_count(100);
    let dimension = random_lwe_dimension(1000);
    let std = LogStandardDev::from_log_standard_dev(-25.);

    let mut weights = CleartextList::allocate(T::ZERO, CleartextCount(nb_ct.0));
    let mut s_weights = CleartextList::allocate(T::Signed::ZERO, CleartextCount(nb_ct.0));
    for (w, sw) in weights
        .as_mut_tensor()
        .iter_mut()
        .zip(s_weights.as_mut_tensor().iter_mut())
    {
        *sw = random_utorus_between::<T>(T::ZERO..T::cast_from(512)).into_signed()
            - T::cast_from(256).into_signed();
        *w = sw.into_unsigned();
    }
    let bias = Plaintext(random_utorus_between::<T>(T::ZERO..T::cast_from(1024)));

    for i in 0..n_tests {
        let sk = LweSecretKey::generate(dimension);
        let mut messages = PlaintextList::allocate(T::ZERO, PlaintextCount(nb_ct.0));
        fill_with_random_uniform(&mut messages);

        let mut witness = LweList::allocate(T::ZERO, dimension.to_lwe_size(), nb_ct);
        sk.trivial_encrypt_lwe_list(&mut witness, &messages, std);
        let mut ciphertext = LweList::allocate(T::ZERO, dimension.to_lwe_size(), nb_ct);
        sk.encrypt_lwe_list(&mut ciphertext, &messages, std);

        let mut ct_res = LweCiphertext::allocate(T::ZERO, dimension.to_lwe_size());
        let mut ct_res_witness = LweCiphertext::allocate(T::ZERO, dimension.to_lwe_size());
        ct_res.fill_with_multisum_with_bias(&ciphertext, &weights, &bias);
        ct_res_witness.fill_with_multisum_with_bias(&witness, &weights, &bias);

        let mut output = Plaintext(T::ZERO);
        sk.decrypt_lwe(&mut output, &ct_res);
        new_msg.set_element(i, output.0);
        msg.set_element(i, ct_res_witness.get_body().0);
    }

    let raw_weights: Vec<T> = s_weights
        .as_tensor()
        .iter()
        .map(|sw| sw.into_unsigned())
        .collect();
    // `std.0` is the log2 of the standard deviation, not the standard deviation itself, so
    // the per-input variance must come from `get_variance()` (as in `test_keyswitch`).
    let output_variance: f64 = <T as npe::LWE>::multisum_uncorrelated(
        &vec![std.get_variance(); nb_ct.0],
        &raw_weights,
    );
    // Same (expected, obtained) argument order in both branches.
    if n_tests < 7 {
        assert_delta_std_dev(&msg, &new_msg, Variance::from_variance(output_variance));
    } else {
        assert_noise_distribution(&msg, &new_msg, Variance::from_variance(output_variance));
    }
}
/// Runs the generic multisum noise test over a 32-bit torus.
#[test]
fn test_multisum_u32() {
    test_multisum_npe::<u32>()
}
/// Runs the generic multisum noise test over a 64-bit torus.
#[test]
fn test_multisum_u64() {
    test_multisum_npe::<u64>()
}
/// Multiplies `n_tests` LWE ciphertexts by one shared random signed cleartext weight and
/// checks the decryption error against the variance predicted by
/// `npe::LWE::single_scalar_mul`.
fn test_scalar_mul<T>()
where
    T: UnsignedTorus + RandomGenerable<UniformMsb> + npe::LWE + CastFrom<usize>,
{
    let n_tests = 10;
    let nb_ct = CiphertextCount(n_tests);
    let dimension = LweDimension(600);
    let std_dev = LogStandardDev::from_log_standard_dev(-15.);
    let sk = LweSecretKey::generate(dimension);

    let mut messages = PlaintextList::allocate(T::ZERO, PlaintextCount(nb_ct.0));
    fill_with_random_uniform(&mut messages);
    let mut ciphertexts = LweList::allocate(T::ZERO, dimension.to_lwe_size(), nb_ct);
    sk.encrypt_lwe_list(&mut ciphertexts, &messages, std_dev);

    // Signed weight drawn from `0..1024` and shifted by -512, then stored through
    // `into_unsigned`.
    let weight = Cleartext(
        (random_utorus_between::<T>(T::ZERO..T::cast_from(1024)).into_signed()
            - T::cast_from(512).into_signed())
        .into_unsigned(),
    );

    let mut ciphertext_sm = LweList::allocate(T::ZERO, dimension.to_lwe_size(), nb_ct);
    ciphertext_sm
        .ciphertext_iter_mut()
        .zip(ciphertexts.ciphertext_iter())
        .for_each(|(mut out, inp)| out.fill_with_scalar_mul(&inp, &weight));

    // Expected plaintexts: each message multiplied (wrapping) by the weight.
    let mut messages_mul = Tensor::allocate(T::ZERO, nb_ct.0);
    messages_mul.fill_with_one(messages.as_tensor(), |m| m.wrapping_mul(weight.0));

    let mut decryptions = PlaintextList::allocate(T::ZERO, PlaintextCount(nb_ct.0));
    sk.decrypt_lwe_list(&mut decryptions, &ciphertext_sm);

    // `std_dev.0` is the log2 of the standard deviation; squaring it does not give the
    // variance, so use `get_variance()` (as in `test_keyswitch`).
    let output_variance: f64 =
        <T as npe::LWE>::single_scalar_mul(std_dev.get_variance(), weight.0);
    if nb_ct.0 < 7 {
        assert_delta_std_dev(
            &messages_mul,
            &decryptions,
            Variance::from_variance(output_variance),
        );
    } else {
        assert_noise_distribution(
            &messages_mul,
            &decryptions,
            Variance::from_variance(output_variance),
        );
    }
}
/// Runs the generic single-weight scalar multiplication test over a 32-bit torus.
#[test]
fn test_scalar_mul_u32() {
    test_scalar_mul::<u32>()
}
/// Runs the generic single-weight scalar multiplication test over a 64-bit torus.
#[test]
fn test_scalar_mul_u64() {
    test_scalar_mul::<u64>()
}
/// Multiplies each of a random number of LWE ciphertexts by its own random signed cleartext
/// weight and checks every decryption individually against the variance predicted by
/// `npe::LWE::single_scalar_mul` for that weight.
fn test_scalar_mul_random<T>()
where
    T: UnsignedTorus + RandomGenerable<UniformMsb> + npe::LWE + CastFrom<usize>,
{
    let nb_ct = random_ciphertext_count(100);
    let dimension = random_lwe_dimension(1000);
    let std_dev = LogStandardDev::from_log_standard_dev(-15.);
    let sk = LweSecretKey::generate(dimension);

    let mut messages = PlaintextList::allocate(T::ZERO, PlaintextCount(nb_ct.0));
    fill_with_random_uniform(&mut messages);
    // Encrypted once; the original re-encrypted the same list a second time for no reason.
    let mut ciphertexts = LweList::allocate(T::ZERO, dimension.to_lwe_size(), nb_ct);
    sk.encrypt_lwe_list(&mut ciphertexts, &messages, std_dev);

    // One signed weight per ciphertext, drawn from `0..1024` shifted by -512 and stored
    // through `into_unsigned`.
    let mut weights = CleartextList::allocate(T::ZERO, CleartextCount(nb_ct.0));
    for w in weights.as_mut_tensor().iter_mut() {
        let val = random_utorus_between::<T>(T::ZERO..T::cast_from(1024)).into_signed()
            - T::cast_from(512).into_signed();
        *w = val.into_unsigned();
    }

    let mut ciphertexts_out = LweList::allocate(T::ZERO, dimension.to_lwe_size(), nb_ct);
    for (mut out, (inp, w)) in ciphertexts_out
        .ciphertext_iter_mut()
        .zip(ciphertexts.ciphertext_iter().zip(weights.cleartext_iter()))
    {
        out.fill_with_scalar_mul(&inp, &w);
    }

    // Expected plaintexts: each message multiplied (wrapping) by its own weight.
    let mut messages_mul = PlaintextList::allocate(T::ZERO, PlaintextCount(nb_ct.0));
    for (mm, (w_i, m)) in messages_mul
        .plaintext_iter_mut()
        .zip(weights.cleartext_iter().zip(messages.plaintext_iter()))
    {
        *mm = Plaintext(w_i.0.wrapping_mul(m.0));
    }

    let mut decryptions = PlaintextList::allocate(T::ZERO, PlaintextCount(nb_ct.0));
    sk.decrypt_lwe_list(&mut decryptions, &ciphertexts_out);

    // `std_dev.0` is the log2 of the standard deviation, so the input variance must come
    // from `get_variance()`; it is loop-invariant and computed once.
    let input_variance = std_dev.get_variance();
    for (mm, (d, w_i)) in messages_mul.sublist_iter(PlaintextCount(1)).zip(
        decryptions
            .sublist_iter(PlaintextCount(1))
            .zip(weights.cleartext_iter()),
    ) {
        let output_variance: f64 = <T as npe::LWE>::single_scalar_mul(input_variance, w_i.0);
        assert_delta_std_dev(&mm, &d, Variance::from_variance(output_variance));
    }
}
/// Runs the generic per-ciphertext-weight scalar multiplication test over a 32-bit torus.
#[test]
fn test_scalar_mul_random_u32() {
    test_scalar_mul_random::<u32>();
}
/// Runs the generic per-ciphertext-weight scalar multiplication test over a 64-bit torus.
#[test]
fn test_scalar_mul_random_u64() {
    test_scalar_mul_random::<u64>();
}