use crate::{CKKSInfos, leveled::api::CKKSImagOps};
use super::helpers::{
TestContextBackend, TestContextModule, TestScalar, alloc_ct, alloc_scratch, assert_ct_meta, assert_decrypt_precision,
assert_unary_output_meta, ckks_encrypt, gen_sk, test_vector_1, want_div_i, want_mul_i,
};
use poulpy_hal::{
api::{NegacyclicFFT, NegacyclicFFTNew, ScratchOwnedBorrow},
layouts::{HostBytesBackend, Module},
};
use crate::{encoding::reim::Encoder, test_suite::CKKSTestParams};
/// Checks `ckks_mul_i_into` when the output ciphertext is allocated with the
/// same `params.k` as the input.
///
/// The test encrypts `test_vector_1` under a secret key generated from an
/// all-zero seed, applies `ckks_mul_i_into` into a fresh ciphertext, then
/// asserts the output metadata against the input (`assert_unary_output_meta`)
/// and the decrypted values against `want_mul_i` (`assert_decrypt_precision`).
///
/// NOTE(review): `mul_i` is presumably multiplication by the imaginary unit;
/// the reference semantics live in `want_mul_i` — confirm there.
///
/// # Panics
///
/// Panics if the encoder cannot be built, if `ckks_mul_i_into` returns an
/// error, or if any of the assertion helpers fail.
pub fn test_mul_i_aligned<BE, F, E>(params: CKKSTestParams, module: &Module<BE>, host_module: &Module<HostBytesBackend>)
where
    BE: TestContextBackend,
    Module<BE>: TestContextModule<BE>,
    F: TestScalar,
    E: NegacyclicFFT<F> + NegacyclicFFTNew<F>,
{
    // The encoder is sized with half the ring degree.
    let m = params.n / 2;
    let encoder = Encoder::<E>::new(m).unwrap();
    let (re1, im1) = test_vector_1::<F>(m);
    let sk = gen_sk(&params, module, host_module, [0u8; 32]);
    let mut scratch = alloc_scratch(&params, module);
    let ct = ckks_encrypt(
        &params,
        module,
        host_module,
        &encoder,
        &sk,
        params.k,
        &re1,
        &im1,
        &mut scratch.borrow(),
    );
    let (want_re, want_im) = want_mul_i(&re1, &im1);
    let mut ct_res = alloc_ct(&params, module, params.k);
    module.ckks_mul_i_into(&mut ct_res, &ct, &mut scratch.borrow()).unwrap();
    assert_unary_output_meta("mul_i", &ct_res, &ct);
    assert_decrypt_precision(
        "mul_i",
        &params,
        module,
        &encoder,
        &ct_res,
        &sk,
        &want_re,
        &want_im,
        &mut scratch.borrow(),
    );
}
/// Checks `ckks_mul_i_into` when the output ciphertext is allocated with
/// `params.k - params.base2k - 1` instead of the input's `params.k`.
///
/// The test encrypts `test_vector_1` at `params.k`, applies `ckks_mul_i_into`
/// into the smaller output, then asserts the output metadata against the
/// input and the decrypted values against `want_mul_i`.
///
/// # Panics
///
/// Panics if the encoder cannot be built, if `ckks_mul_i_into` returns an
/// error, or if any of the assertion helpers fail.
pub fn test_mul_i_smaller_output<BE, F, E>(params: CKKSTestParams, module: &Module<BE>, host_module: &Module<HostBytesBackend>)
where
    BE: TestContextBackend,
    Module<BE>: TestContextModule<BE>,
    F: TestScalar,
    E: NegacyclicFFT<F> + NegacyclicFFTNew<F>,
{
    // The encoder is sized with half the ring degree.
    let m = params.n / 2;
    let encoder = Encoder::<E>::new(m).unwrap();
    let (re1, im1) = test_vector_1::<F>(m);
    let sk = gen_sk(&params, module, host_module, [0u8; 32]);
    let mut scratch = alloc_scratch(&params, module);
    let ct = ckks_encrypt(
        &params,
        module,
        host_module,
        &encoder,
        &sk,
        params.k,
        &re1,
        &im1,
        &mut scratch.borrow(),
    );
    let (want_re, want_im) = want_mul_i(&re1, &im1);
    // Output is deliberately allocated smaller than the input ciphertext.
    let mut ct_res = alloc_ct(&params, module, params.k - params.base2k - 1);
    module.ckks_mul_i_into(&mut ct_res, &ct, &mut scratch.borrow()).unwrap();
    assert_unary_output_meta("mul_i smaller_output", &ct_res, &ct);
    // NOTE(review): this label is "mul_i", identical to the aligned test, while
    // the metadata check above uses "mul_i smaller_output" — confirm whether the
    // shorter label is intentional before changing it.
    assert_decrypt_precision(
        "mul_i",
        &params,
        module,
        &encoder,
        &ct_res,
        &sk,
        &want_re,
        &want_im,
        &mut scratch.borrow(),
    );
}
/// Checks the in-place variant `ckks_mul_i_assign`.
///
/// The test encrypts `test_vector_1` at `params.k`, records the ciphertext's
/// `log_delta` and `log_budget`, applies `ckks_mul_i_assign`, then passes the
/// recorded values to `assert_ct_meta` and compares the decrypted result
/// against `want_mul_i`.
///
/// # Panics
///
/// Panics if the encoder cannot be built, if `ckks_mul_i_assign` returns an
/// error, or if any of the assertion helpers fail.
pub fn test_mul_i_assign<BE, F, E>(params: CKKSTestParams, module: &Module<BE>, host_module: &Module<HostBytesBackend>)
where
    BE: TestContextBackend,
    Module<BE>: TestContextModule<BE>,
    F: TestScalar,
    E: NegacyclicFFT<F> + NegacyclicFFTNew<F>,
{
    // The encoder is sized with half the ring degree.
    let m = params.n / 2;
    let encoder = Encoder::<E>::new(m).unwrap();
    let (re1, im1) = test_vector_1::<F>(m);
    let sk = gen_sk(&params, module, host_module, [0u8; 32]);
    let mut scratch = alloc_scratch(&params, module);
    let mut ct = ckks_encrypt(
        &params,
        module,
        host_module,
        &encoder,
        &sk,
        params.k,
        &re1,
        &im1,
        &mut scratch.borrow(),
    );
    let (want_re, want_im) = want_mul_i(&re1, &im1);
    // Snapshot the metadata before the ciphertext is mutated in place.
    let expected_log_delta = ct.log_delta();
    let expected_log_budget = ct.log_budget();
    module.ckks_mul_i_assign(&mut ct, &mut scratch.borrow()).unwrap();
    assert_ct_meta("mul_i_assign", &ct, expected_log_delta, expected_log_budget);
    assert_decrypt_precision(
        "mul_i_assign",
        &params,
        module,
        &encoder,
        &ct,
        &sk,
        &want_re,
        &want_im,
        &mut scratch.borrow(),
    );
}
/// Checks `ckks_div_i_into` when the output ciphertext is allocated with the
/// same `params.k` as the input.
///
/// The test encrypts `test_vector_1` under a secret key generated from an
/// all-zero seed, applies `ckks_div_i_into` into a fresh ciphertext, then
/// asserts the output metadata against the input (`assert_unary_output_meta`)
/// and the decrypted values against `want_div_i` (`assert_decrypt_precision`).
///
/// NOTE(review): `div_i` is presumably division by the imaginary unit; the
/// reference semantics live in `want_div_i` — confirm there.
///
/// # Panics
///
/// Panics if the encoder cannot be built, if `ckks_div_i_into` returns an
/// error, or if any of the assertion helpers fail.
pub fn test_div_i_aligned<BE, F, E>(params: CKKSTestParams, module: &Module<BE>, host_module: &Module<HostBytesBackend>)
where
    BE: TestContextBackend,
    Module<BE>: TestContextModule<BE>,
    F: TestScalar,
    E: NegacyclicFFT<F> + NegacyclicFFTNew<F>,
{
    // The encoder is sized with half the ring degree.
    let m = params.n / 2;
    let encoder = Encoder::<E>::new(m).unwrap();
    let (re1, im1) = test_vector_1::<F>(m);
    let sk = gen_sk(&params, module, host_module, [0u8; 32]);
    let mut scratch = alloc_scratch(&params, module);
    let ct = ckks_encrypt(
        &params,
        module,
        host_module,
        &encoder,
        &sk,
        params.k,
        &re1,
        &im1,
        &mut scratch.borrow(),
    );
    let (want_re, want_im) = want_div_i(&re1, &im1);
    let mut ct_res = alloc_ct(&params, module, params.k);
    module.ckks_div_i_into(&mut ct_res, &ct, &mut scratch.borrow()).unwrap();
    assert_unary_output_meta("div_i", &ct_res, &ct);
    assert_decrypt_precision(
        "div_i",
        &params,
        module,
        &encoder,
        &ct_res,
        &sk,
        &want_re,
        &want_im,
        &mut scratch.borrow(),
    );
}
/// Checks `ckks_div_i_into` when the output ciphertext is allocated with
/// `params.k - params.base2k - 1` instead of the input's `params.k`.
///
/// The test encrypts `test_vector_1` at `params.k`, applies `ckks_div_i_into`
/// into the smaller output, then asserts the output metadata against the
/// input and the decrypted values against `want_div_i`.
///
/// # Panics
///
/// Panics if the encoder cannot be built, if `ckks_div_i_into` returns an
/// error, or if any of the assertion helpers fail.
pub fn test_div_i_smaller_output<BE, F, E>(params: CKKSTestParams, module: &Module<BE>, host_module: &Module<HostBytesBackend>)
where
    BE: TestContextBackend,
    Module<BE>: TestContextModule<BE>,
    F: TestScalar,
    E: NegacyclicFFT<F> + NegacyclicFFTNew<F>,
{
    // The encoder is sized with half the ring degree.
    let m = params.n / 2;
    let encoder = Encoder::<E>::new(m).unwrap();
    let (re1, im1) = test_vector_1::<F>(m);
    let sk = gen_sk(&params, module, host_module, [0u8; 32]);
    let mut scratch = alloc_scratch(&params, module);
    let ct = ckks_encrypt(
        &params,
        module,
        host_module,
        &encoder,
        &sk,
        params.k,
        &re1,
        &im1,
        &mut scratch.borrow(),
    );
    let (want_re, want_im) = want_div_i(&re1, &im1);
    // Output is deliberately allocated smaller than the input ciphertext.
    let mut ct_res = alloc_ct(&params, module, params.k - params.base2k - 1);
    module.ckks_div_i_into(&mut ct_res, &ct, &mut scratch.borrow()).unwrap();
    assert_unary_output_meta("div_i smaller_output", &ct_res, &ct);
    // NOTE(review): this label is "div_i", identical to the aligned test, while
    // the metadata check above uses "div_i smaller_output" — confirm whether the
    // shorter label is intentional before changing it.
    assert_decrypt_precision(
        "div_i",
        &params,
        module,
        &encoder,
        &ct_res,
        &sk,
        &want_re,
        &want_im,
        &mut scratch.borrow(),
    );
}
/// Checks the in-place variant `ckks_div_i_assign`.
///
/// The test encrypts `test_vector_1` at `params.k`, records the ciphertext's
/// `log_delta` and `log_budget`, applies `ckks_div_i_assign`, then passes the
/// recorded values to `assert_ct_meta` and compares the decrypted result
/// against `want_div_i`.
///
/// # Panics
///
/// Panics if the encoder cannot be built, if `ckks_div_i_assign` returns an
/// error, or if any of the assertion helpers fail.
pub fn test_div_i_assign<BE, F, E>(params: CKKSTestParams, module: &Module<BE>, host_module: &Module<HostBytesBackend>)
where
    BE: TestContextBackend,
    Module<BE>: TestContextModule<BE>,
    F: TestScalar,
    E: NegacyclicFFT<F> + NegacyclicFFTNew<F>,
{
    // The encoder is sized with half the ring degree.
    let m = params.n / 2;
    let encoder = Encoder::<E>::new(m).unwrap();
    let (re1, im1) = test_vector_1::<F>(m);
    let sk = gen_sk(&params, module, host_module, [0u8; 32]);
    let mut scratch = alloc_scratch(&params, module);
    let mut ct = ckks_encrypt(
        &params,
        module,
        host_module,
        &encoder,
        &sk,
        params.k,
        &re1,
        &im1,
        &mut scratch.borrow(),
    );
    let (want_re, want_im) = want_div_i(&re1, &im1);
    // Snapshot the metadata before the ciphertext is mutated in place.
    let expected_log_delta = ct.log_delta();
    let expected_log_budget = ct.log_budget();
    module.ckks_div_i_assign(&mut ct, &mut scratch.borrow()).unwrap();
    assert_ct_meta("div_i_assign", &ct, expected_log_delta, expected_log_budget);
    assert_decrypt_precision(
        "div_i_assign",
        &params,
        module,
        &encoder,
        &ct,
        &sk,
        &want_re,
        &want_im,
        &mut scratch.borrow(),
    );
}