use poulpy_hal::{
api::{NegacyclicFFT, NegacyclicFFTNew, ScratchOwnedBorrow},
layouts::{HostBytesBackend, Module},
};
use crate::{CKKSInfos, layouts::UnnormalizedCKKSCiphertext, leveled::api::CKKSAddOpsUnnormalized};
use super::helpers::{
ADD_SUB_CONST, PT_PREC, TestContextBackend, TestContextModule, TestScalar, add_sub_const_pt, alloc_ct, alloc_scratch,
assert_binary_output_meta, assert_ct_meta, assert_decrypt_precision, assert_unary_output_meta, ckks_encrypt,
encode_and_upload_pt, gen_sk, quantize, quantized_const, test_vector_1, test_vector_2, want_add, want_add_const,
};
use crate::{encoding::reim::Encoder, test_suite::CKKSTestParams};
/// Checks ciphertext + ciphertext addition through the unnormalized
/// ("unsafe") out-of-place entry point `ckks_add_into_unnormalized`.
///
/// Two test vectors are encrypted with the same `params.k`, summed into a
/// freshly allocated [`UnnormalizedCKKSCiphertext`], and the result is
/// verified twice: its metadata against both inputs before normalization,
/// and its decrypted value against [`want_add`] after normalization.
///
/// # Type parameters
/// * `BE` - backend the ciphertext operations run on.
/// * `F`  - scalar type of the cleartext test vectors.
/// * `E`  - negacyclic FFT implementation backing the [`Encoder`].
///
/// # Panics
/// Panics if the encoder cannot be built, if the addition returns an error,
/// or if any of the assertion helpers fails.
pub fn test_add_ct_aligned_unsafe<BE, F, E>(params: CKKSTestParams, module: &Module<BE>, host_module: &Module<HostBytesBackend>)
where
    BE: TestContextBackend,
    Module<BE>: TestContextModule<BE>,
    F: TestScalar,
    E: NegacyclicFFT<F> + NegacyclicFFTNew<F>,
{
    // The encoder and the test vectors are all sized by `n / 2`.
    let m = params.n / 2;
    let encoder = Encoder::<E>::new(m).unwrap();
    let (re1, im1) = test_vector_1::<F>(m);
    let (re2, im2) = test_vector_2::<F>(m);

    // Fixed all-zero seed for the secret key.
    let sk = gen_sk(&params, module, host_module, [0u8; 32]);
    let mut scratch = alloc_scratch(&params, module);

    let ct1 = ckks_encrypt(
        &params,
        module,
        host_module,
        &encoder,
        &sk,
        params.k,
        &re1,
        &im1,
        &mut scratch.borrow(),
    );
    let ct2 = ckks_encrypt(
        &params,
        module,
        host_module,
        &encoder,
        &sk,
        params.k,
        &re2,
        &im2,
        &mut scratch.borrow(),
    );

    // Cleartext reference result.
    let (want_re, want_im) = want_add(&re1, &im1, &re2, &im2);

    let mut ct_res = UnnormalizedCKKSCiphertext::new(alloc_ct(&params, module, params.k));
    module
        .ckks_add_into_unnormalized(&mut ct_res, &ct1, &ct2, &mut scratch.borrow())
        .unwrap();

    // Metadata is checked on the unnormalized result, i.e. before `normalize`.
    assert_binary_output_meta("add_ct_aligned_unsafe", &ct_res, &ct1, &ct2);

    let ct_res = ct_res.normalize(module, &mut scratch.borrow());
    assert_decrypt_precision(
        "add_ct_aligned_unsafe",
        &params,
        module,
        &encoder,
        &ct_res,
        &sk,
        &want_re,
        &want_im,
        &mut scratch.borrow(),
    );
}
/// Checks in-place ciphertext + ciphertext addition through the unnormalized
/// ("unsafe") entry point `ckks_add_assign_unnormalized`.
///
/// The first ciphertext is wrapped in an [`UnnormalizedCKKSCiphertext`] and
/// used as the accumulator. The expected metadata is captured from the raw
/// inputs before the wrap: the smaller of the two `log_budget` values and the
/// larger of the two `log_delta` values. After the addition the accumulator's
/// metadata is compared to those, then it is normalized and its decryption is
/// compared to [`want_add`].
///
/// # Type parameters
/// * `BE` - backend the ciphertext operations run on.
/// * `F`  - scalar type of the cleartext test vectors.
/// * `E`  - negacyclic FFT implementation backing the [`Encoder`].
///
/// # Panics
/// Panics if the encoder cannot be built, if the addition returns an error,
/// or if any of the assertion helpers fails.
pub fn test_add_ct_assign_aligned_unsafe<BE, F, E>(
    params: CKKSTestParams,
    module: &Module<BE>,
    host_module: &Module<HostBytesBackend>,
) where
    BE: TestContextBackend,
    Module<BE>: TestContextModule<BE>,
    F: TestScalar,
    E: NegacyclicFFT<F> + NegacyclicFFTNew<F>,
{
    // The encoder and the test vectors are all sized by `n / 2`.
    let m = params.n / 2;
    let encoder = Encoder::<E>::new(m).unwrap();
    let (re1, im1) = test_vector_1::<F>(m);
    let (re2, im2) = test_vector_2::<F>(m);

    // Fixed all-zero seed for the secret key.
    let sk = gen_sk(&params, module, host_module, [0u8; 32]);
    let mut scratch = alloc_scratch(&params, module);

    let ct1_raw = ckks_encrypt(
        &params,
        module,
        host_module,
        &encoder,
        &sk,
        params.k,
        &re1,
        &im1,
        &mut scratch.borrow(),
    );
    let ct2 = ckks_encrypt(
        &params,
        module,
        host_module,
        &encoder,
        &sk,
        params.k,
        &re2,
        &im2,
        &mut scratch.borrow(),
    );

    // Cleartext reference result.
    let (want_re, want_im) = want_add(&re1, &im1, &re2, &im2);

    // Capture the expected metadata now: `ct1_raw` is moved into the wrapper
    // below and then overwritten by the in-place addition.
    let expected_log_budget = ct1_raw.log_budget().min(ct2.log_budget());
    let expected_log_delta = ct1_raw.log_delta().max(ct2.log_delta());

    let mut ct1 = UnnormalizedCKKSCiphertext::new(ct1_raw);
    module
        .ckks_add_assign_unnormalized(&mut ct1, &ct2, &mut scratch.borrow())
        .unwrap();

    // Metadata is checked on the unnormalized result, i.e. before `normalize`.
    assert_ct_meta("add_ct_assign_aligned_unsafe", &ct1, expected_log_delta, expected_log_budget);

    let ct1 = ct1.normalize(module, &mut scratch.borrow());
    assert_decrypt_precision(
        "add_ct_assign_aligned_unsafe",
        &params,
        module,
        &encoder,
        &ct1,
        &sk,
        &want_re,
        &want_im,
        &mut scratch.borrow(),
    );
}
/// Checks ciphertext + plaintext-vector addition through the unnormalized
/// ("unsafe") out-of-place entry point `ckks_add_pt_vec_into_unnormalized`.
///
/// The first test vector is encrypted; the second is encoded and uploaded as
/// a plaintext with `params.base2k` and `params.prec`. The sum is written
/// into a freshly allocated [`UnnormalizedCKKSCiphertext`], whose metadata is
/// compared to the input ciphertext before normalization and whose decryption
/// is compared to the reference afterwards.
///
/// The reference adds the first vector to the second vector *quantized* at
/// `params.prec.log_delta`, rather than to the raw second vector.
///
/// # Type parameters
/// * `BE` - backend the ciphertext operations run on.
/// * `F`  - scalar type of the cleartext test vectors.
/// * `E`  - negacyclic FFT implementation backing the [`Encoder`].
///
/// # Panics
/// Panics if the encoder cannot be built, if the addition returns an error,
/// or if any of the assertion helpers fails.
pub fn test_add_pt_vec_into_aligned_unsafe<BE, F, E>(
    params: CKKSTestParams,
    module: &Module<BE>,
    host_module: &Module<HostBytesBackend>,
) where
    BE: TestContextBackend,
    Module<BE>: TestContextModule<BE>,
    F: TestScalar,
    E: NegacyclicFFT<F> + NegacyclicFFTNew<F>,
{
    // The encoder and the test vectors are all sized by `n / 2`.
    let m = params.n / 2;
    let encoder = Encoder::<E>::new(m).unwrap();
    let (re1, im1) = test_vector_1::<F>(m);
    let (re2, im2) = test_vector_2::<F>(m);

    // Fixed all-zero seed for the secret key.
    let sk = gen_sk(&params, module, host_module, [0u8; 32]);
    let mut scratch = alloc_scratch(&params, module);

    let ct1 = ckks_encrypt(
        &params,
        module,
        host_module,
        &encoder,
        &sk,
        params.k,
        &re1,
        &im1,
        &mut scratch.borrow(),
    );
    let pt = encode_and_upload_pt(host_module, module, &encoder, params.base2k.into(), params.prec, &re2, &im2);

    // NOTE(review): the plaintext operand is presumably quantized during
    // encoding, which is why the reference quantizes it too; confirm
    // against `encode_and_upload_pt`.
    let (want_re, want_im) = want_add(
        &re1,
        &im1,
        &quantize(&re2, params.prec.log_delta),
        &quantize(&im2, params.prec.log_delta),
    );

    let mut ct_res = UnnormalizedCKKSCiphertext::new(alloc_ct(&params, module, params.k));
    module
        .ckks_add_pt_vec_into_unnormalized(&mut ct_res, &ct1, &pt, &mut scratch.borrow())
        .unwrap();

    // Metadata is checked on the unnormalized result, i.e. before `normalize`.
    assert_unary_output_meta("add_pt_vec_unsafe", &ct_res, &ct1);

    let ct_res = ct_res.normalize(module, &mut scratch.borrow());
    assert_decrypt_precision(
        "add_pt_vec_unsafe",
        &params,
        module,
        &encoder,
        &ct_res,
        &sk,
        &want_re,
        &want_im,
        &mut scratch.borrow(),
    );
}
/// Checks ciphertext + constant addition through the unnormalized ("unsafe")
/// entry points `ckks_add_pt_const_into_unnormalized` and
/// `ckks_add_pt_const_assign_unnormalized`.
///
/// The first test vector is encrypted and the constant plaintext comes from
/// [`add_sub_const_pt`]. The constant is applied in two steps: an
/// out-of-place call with indices `(0, 0)` that writes into a fresh
/// [`UnnormalizedCKKSCiphertext`], then an in-place call with indices
/// `(m, 1)` on that same result. The metadata is then compared to the input
/// ciphertext, and after normalization the decryption is compared to
/// [`want_add_const`] evaluated with the constant quantized at
/// `PT_PREC.log_delta`.
///
/// # Type parameters
/// * `BE` - backend the ciphertext operations run on.
/// * `F`  - scalar type of the cleartext test vector and constant.
/// * `E`  - negacyclic FFT implementation backing the [`Encoder`].
///
/// # Panics
/// Panics if the encoder cannot be built, if either addition returns an
/// error, or if any of the assertion helpers fails.
pub fn test_add_const_into_aligned_unsafe<BE, F, E>(
    params: CKKSTestParams,
    module: &Module<BE>,
    host_module: &Module<HostBytesBackend>,
) where
    BE: TestContextBackend,
    Module<BE>: TestContextModule<BE>,
    F: TestScalar,
    E: NegacyclicFFT<F> + NegacyclicFFTNew<F>,
{
    // The encoder and the test vector are sized by `n / 2`; `m` is also used
    // as an index in the second constant addition below.
    let m = params.n / 2;
    let encoder = Encoder::<E>::new(m).unwrap();
    let (re1, im1) = test_vector_1::<F>(m);

    // Fixed all-zero seed for the secret key.
    let sk = gen_sk(&params, module, host_module, [0u8; 32]);
    let mut scratch = alloc_scratch(&params, module);

    let ct = ckks_encrypt(
        &params,
        module,
        host_module,
        &encoder,
        &sk,
        params.k,
        &re1,
        &im1,
        &mut scratch.borrow(),
    );

    // Cleartext reference: the input vector plus the quantized constant.
    let (const_re, const_im) = quantized_const::<F>(ADD_SUB_CONST.0, ADD_SUB_CONST.1, PT_PREC.log_delta);
    let (want_re, want_im) = want_add_const(&re1, &im1, const_re, const_im);

    let mut ct_res = UnnormalizedCKKSCiphertext::new(alloc_ct(&params, module, params.k));
    let cst = add_sub_const_pt::<BE, F>(host_module, module, params.base2k.into());

    // NOTE(review): presumably (0, 0) adds the real part of the constant and
    // (m, 1) adds its imaginary part; confirm against the
    // `CKKSAddOpsUnnormalized` documentation.
    module
        .ckks_add_pt_const_into_unnormalized(&mut ct_res, &ct, 0, &cst, 0, &mut scratch.borrow())
        .unwrap();
    module
        .ckks_add_pt_const_assign_unnormalized(&mut ct_res, m, &cst, 1, &mut scratch.borrow())
        .unwrap();

    // Metadata is checked on the unnormalized result, i.e. before `normalize`.
    assert_unary_output_meta("add_const_into_aligned_unsafe", &ct_res, &ct);

    let ct_res = ct_res.normalize(module, &mut scratch.borrow());
    assert_decrypt_precision(
        "add_const_into_aligned_unsafe",
        &params,
        module,
        &encoder,
        &ct_res,
        &sk,
        &want_re,
        &want_im,
        &mut scratch.borrow(),
    );
}