use super::*;
use crate::core_crypto::prelude::new_seeder;
use crate::prelude::*;
use crate::shortint::parameters::test_params::*;
use crate::xof_key_set::{CompressedXofKeySet, XofKeySet};
use crate::*;
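/// Generates a compressed XOF key set for the given config, checks that the tag
/// propagates to the client key and the compressed public key, then runs the full
/// key set test on the requested device. Optionally verifies that XOF expansion
/// matches the classic expansion path.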
fn run_xof_key_set_test(config: Config, tag_str: &str, device: Device, check_expansion: bool) {
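    // Draw a fresh private seed from the platform seeder; its bytes drive the
    // XOF-based key generation.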
let mut seeder = new_seeder();
let private_seed_bytes = seeder.seed().0.to_le_bytes().to_vec();
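    // Key generation parameters: 128-bit security target and the accepted normalized
    // Hamming weight bound for the generated secret keys.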
let security_bits = 128;
let max_norm_hwt = NormalizedHammingWeightBound::new(0.8).unwrap();
let tag = Tag::from(tag_str);
let (cks, compressed_key_set) = CompressedXofKeySet::generate(
config,
private_seed_bytes,
security_bits,
max_norm_hwt,
tag.clone(),
)
.unwrap();
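    // The tag must be carried by the client key, the compressed public key,
    // and the compressed key set itself.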
assert_eq!(cks.tag(), compressed_key_set.compressed_public_key.tag());
assert_eq!(cks.tag(), &tag);
assert_eq!(compressed_key_set.tag(), &tag);
test_xof_key_set(&compressed_key_set, config, device, &cks);
if check_expansion {
test_xof_expansion_is_same_as_classic(compressed_key_set);
}
}
mod cpu {
use super::*;
#[test]
fn test_xof_key_set_legacy_rerand_classic_params() {
run_xof_key_set_test(
TEST_LEGACY_RERAND_META_PARAM_CPU_2_2_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128.into(),
"classic_2_2",
Device::Cpu,
false,
);
}
#[test]
fn test_xof_key_set_legacy_rerand_ks32_params_big_pke() {
run_xof_key_set_test(
TEST_LEGACY_RERAND_META_PARAM_CPU_2_2_KS32_PBS_PKE_TO_BIG_ZKV2_TUNIFORM_2M128.into(),
"ks32 big pke",
Device::Cpu,
false,
);
}
#[test]
fn test_xof_key_set_legacy_rerand_ks32_params_small_pke() {
run_xof_key_set_test(
TEST_LEGACY_RERAND_META_PARAM_CPU_2_2_KS32_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128.into(),
"ks32 small pke",
Device::Cpu,
false,
);
}
#[test]
fn test_xof_key_set_classic_params() {
run_xof_key_set_test(
TEST_META_PARAM_CPU_2_2_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128.into(),
"classic_2_2",
Device::Cpu,
true,
);
}
#[test]
fn test_xof_key_set_ks32_params_big_pke() {
run_xof_key_set_test(
TEST_META_PARAM_CPU_2_2_KS32_PBS_PKE_TO_BIG_ZKV2_TUNIFORM_2M128.into(),
"ks32 big pke",
Device::Cpu,
true,
);
}
#[test]
fn test_xof_key_set_ks32_params_small_pke() {
run_xof_key_set_test(
TEST_META_PARAM_CPU_2_2_KS32_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128.into(),
"ks32 small pke",
Device::Cpu,
true,
);
}
}
#[cfg(feature = "gpu")]
mod gpu {
use super::*;
#[test]
fn test_xof_key_set_legacy_rerand_multibit_group_4_small_pke() {
run_xof_key_set_test(
            TEST_LEGACY_RERAND_META_PARAM_GPU_2_2_MULTI_BIT_GROUP_4_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128
                .into(),
"gpu multibit group 4",
Device::CudaGpu,
true,
);
}
#[test]
fn test_xof_key_set_legacy_rerand_multibit_group_4_big_pke() {
run_xof_key_set_test(
            TEST_LEGACY_RERAND_META_PARAM_GPU_2_2_MULTI_BIT_GROUP_4_KS_PBS_PKE_TO_BIG_ZKV2_TUNIFORM_2M128
                .into(),
"gpu multibit group 4",
Device::CudaGpu,
true,
);
}
#[test]
fn test_xof_key_set_multibit_group_4_small_pke() {
run_xof_key_set_test(
TEST_META_PARAM_GPU_2_2_MULTI_BIT_GROUP_4_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128
.into(),
"gpu multibit group 4 small pke",
Device::CudaGpu,
true,
);
}
#[test]
fn test_xof_key_set_multibit_group_4_big_pke() {
run_xof_key_set_test(
TEST_META_PARAM_GPU_2_2_MULTI_BIT_GROUP_4_KS_PBS_PKE_TO_BIG_ZKV2_TUNIFORM_2M128.into(),
"gpu multibit group 4 big pke",
Device::CudaGpu,
true,
);
}
#[test]
fn test_xof_key_set_with_cpu_params() {
run_xof_key_set_test(
TEST_META_PARAM_CPU_2_2_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128.into(),
"gpu with cpu v1.6 params",
Device::CudaGpu,
true,
);
}
#[test]
fn test_xof_key_set_legacy_rerand_with_cpu_params() {
run_xof_key_set_test(
TEST_LEGACY_RERAND_META_PARAM_CPU_2_2_KS_PBS_PKE_TO_SMALL_ZKV2_TUNIFORM_2M128.into(),
"gpu with cpu params",
Device::CudaGpu,
true,
);
}
}
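/// Verifies that expanding a `CompressedXofKeySet` yields the same keys as
/// decompressing/expanding its raw parts through the classic path.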
fn test_xof_expansion_is_same_as_classic(key_set: CompressedXofKeySet) {
let (xof_pk, xof_sk) = key_set.expand();
let (_seed, cpk, csk) = key_set.into_raw_parts();
assert_eq!(cpk.tag(), csk.tag());
let pk = cpk.decompress();
let sk = csk.integer_key.expand();
#[allow(
clippy::manual_assert,
reason = "The type does not impl Debug, and if it did, the output would be unreadable"
)]
if sk != xof_sk {
panic!("Expanded server keys are not equal");
}
assert_eq!(pk, xof_pk);
}
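/// Exercises a compressed XOF key set end to end: safe serialization round-trips,
/// tag handling, conformance, decompression on the target device, then encrypted
/// computations, re-randomization (when supported), noise squashing and ciphertext
/// list compression.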
fn test_xof_key_set(
compressed_key_set: &CompressedXofKeySet,
config: Config,
device: Device,
cks: &ClientKey,
) {
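    // Round-trip the compressed key set through safe (de)serialization.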
    let compressed_size_limit = 1 << 32;
    let mut data = vec![];
crate::safe_serialization::safe_serialize(compressed_key_set, &mut data, compressed_size_limit)
.unwrap();
let compressed_key_set: CompressedXofKeySet =
crate::safe_serialization::safe_deserialize(data.as_slice(), compressed_size_limit)
.unwrap();
let expected_tag = cks.tag().clone();
assert_eq!(compressed_key_set.tag(), &expected_tag);
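    // Changing the tag on a clone must not affect the original, and the new tag
    // must propagate to the raw parts.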
{
let mut compressed_clone = compressed_key_set.clone();
compressed_clone.tag_mut().set_u64(0xDEAD);
assert_eq!(compressed_clone.tag().as_u64(), 0xDEAD);
assert_eq!(compressed_key_set.tag(), &expected_tag);
let (_seed, cpk, csk) = compressed_clone.into_raw_parts();
assert_eq!(cpk.tag().as_u64(), 0xDEAD);
assert_eq!(csk.tag().as_u64(), 0xDEAD);
}
assert!(compressed_key_set.is_conformant(&config));
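    // Decompress on the requested device, check tags on the expanded parts, and
    // install the server key; keep the compact public key for encryption.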
let cpk = match device {
Device::Cpu => {
let key_set = compressed_key_set.decompress().unwrap();
            let size_limit = 1 << 32;
            let mut data = vec![];
crate::safe_serialization::safe_serialize(&key_set, &mut data, size_limit).unwrap();
let mut key_set: XofKeySet =
crate::safe_serialization::safe_deserialize(data.as_slice(), size_limit).unwrap();
assert_eq!(key_set.tag(), &expected_tag);
key_set.tag_mut().set_u64(0xCAFE);
assert_eq!(key_set.tag().as_u64(), 0xCAFE);
let (pk, sk) = key_set.into_raw_parts();
assert_eq!(pk.tag().as_u64(), 0xCAFE);
assert_eq!(sk.tag().as_u64(), 0xCAFE);
assert!(sk.is_conformant(&config.into()));
set_server_key(sk);
pk
}
#[cfg(feature = "gpu")]
Device::CudaGpu => {
let mut key_set = compressed_key_set.decompress_to_gpu().unwrap();
assert_eq!(key_set.tag(), &expected_tag);
key_set.tag_mut().set_u64(0xCAFE);
assert_eq!(key_set.tag().as_u64(), 0xCAFE);
let (pk, sk) = key_set.into_raw_parts();
assert_eq!(pk.tag().as_u64(), 0xCAFE);
assert_eq!(sk.tag().as_u64(), 0xCAFE);
set_server_key(sk);
pk
}
#[cfg(feature = "hpu")]
Device::Hpu => {
panic!("HPU not supported in this test")
}
};
let cpk = &cpk;
let clear_a = rand::random::<u32>();
let clear_b = rand::random::<u32>();
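    // Sanity check: encrypt with the client key and compute with the installed
    // server key.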
{
let a = FheUint32::encrypt(clear_a, cks);
let b = FheUint32::encrypt(clear_b, cks);
let c = &a * &b;
let d = &a & &b;
let c_dec: u32 = c.decrypt(cks);
let d_dec: u32 = d.decrypt(cks);
assert_eq!(clear_a.wrapping_mul(clear_b), c_dec);
assert_eq!(clear_a & clear_b, d_dec);
}
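    // Encrypt through the compact public key, both packed and non-packed
    // (the non-packed path is skipped on GPU).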
for build_packed in [true, false] {
#[cfg(feature = "gpu")]
if !build_packed && device == Device::CudaGpu {
continue;
}
let mut builder = CompactCiphertextList::builder(cpk);
builder.push(clear_a).push(clear_b);
let list = if build_packed {
builder.build_packed()
} else {
builder.build()
};
let expander = list.expand().unwrap();
let mut a = expander.get::<FheUint32>(0).unwrap().unwrap();
let mut b = expander.get::<FheUint32>(1).unwrap().unwrap();
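        // When the parameters include re-randomization support, re-randomize the
        // freshly expanded ciphertexts before computing on them.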
if config.inner.cpk_re_randomization_params.is_some() {
let nonce: [u8; 256 / 8] = core::array::from_fn(|_| rand::random());
let compact_public_encryption_domain_separator = *b"TFHE_Enc";
let rerand_domain_separator = *b"TFHE_Rrd";
let mut re_rand_context = ReRandomizationContext::new(
rerand_domain_separator,
[b"FheUint32 bin ops".as_slice(), nonce.as_slice()],
compact_public_encryption_domain_separator,
);
re_rand_context.add_ciphertext(&a);
re_rand_context.add_ciphertext(&b);
let mut seed_gen = re_rand_context.finalize();
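            // Pick the re-randomization call matching what the installed server key
            // supports.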
match ServerKey::current_server_key_re_randomization_support().unwrap() {
ReRandomizationSupport::NoSupport => {
panic!("This test runs rerand, the current ServerKey does not support it")
}
ReRandomizationSupport::LegacyDedicatedCPKWithKeySwitch => {
a.re_randomize(
ReRandomizationMode::UseLegacyCPKIfNeeded { cpk },
seed_gen.next_seed().unwrap(),
)
.unwrap();
b.re_randomize(
ReRandomizationMode::UseLegacyCPKIfNeeded { cpk },
seed_gen.next_seed().unwrap(),
)
.unwrap();
}
ReRandomizationSupport::DerivedCPKWithoutKeySwitch => {
a.re_randomize(
ReRandomizationMode::UseAvailableMode,
seed_gen.next_seed().unwrap(),
)
.unwrap();
b.re_randomize(
ReRandomizationMode::UseAvailableMode,
seed_gen.next_seed().unwrap(),
)
.unwrap();
}
}
}
let c = &a * &b;
let d = &a & &b;
let c_dec: u32 = c.decrypt(cks);
let d_dec: u32 = d.decrypt(cks);
assert_eq!(clear_a.wrapping_mul(clear_b), c_dec);
assert_eq!(clear_a & clear_b, d_dec);
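        // Noise squashing must preserve the decrypted values.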
let ns_c = c.squash_noise().unwrap();
let ns_c_dec: u32 = ns_c.decrypt(cks);
assert_eq!(clear_a.wrapping_mul(clear_b), ns_c_dec);
let ns_d = d.squash_noise().unwrap();
let ns_d_dec: u32 = ns_d.decrypt(cks);
assert_eq!(clear_a & clear_b, ns_d_dec);
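        // Compress all four ciphertexts and check that each entry decrypts correctly.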
let compressed_list = CompressedCiphertextListBuilder::new()
.push(a)
.push(b)
.push(c)
.push(d)
.build()
.unwrap();
let a: FheUint32 = compressed_list.get(0).unwrap().unwrap();
let da: u32 = a.decrypt(cks);
assert_eq!(da, clear_a);
let b: FheUint32 = compressed_list.get(1).unwrap().unwrap();
let db: u32 = b.decrypt(cks);
assert_eq!(db, clear_b);
let c: FheUint32 = compressed_list.get(2).unwrap().unwrap();
let dc: u32 = c.decrypt(cks);
assert_eq!(dc, clear_a.wrapping_mul(clear_b));
let d: FheUint32 = compressed_list.get(3).unwrap().unwrap();
        let dd: u32 = d.decrypt(cks);
        assert_eq!(dd, clear_a & clear_b);
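        // Same round-trip for the squashed-noise ciphertexts.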
let ns_compressed_list = CompressedSquashedNoiseCiphertextListBuilder::new()
.push(ns_c)
.push(ns_d)
.build()
.unwrap();
let ns_c: SquashedNoiseFheUint = ns_compressed_list.get(0).unwrap().unwrap();
let dc: u32 = ns_c.decrypt(cks);
assert_eq!(dc, clear_a.wrapping_mul(clear_b));
let ns_d: SquashedNoiseFheUint = ns_compressed_list.get(1).unwrap().unwrap();
        let dd: u32 = ns_d.decrypt(cks);
        assert_eq!(dd, clear_a & clear_b);
}
}