// poulpy-core 0.6.0
//
// A backend-agnostic crate implementing Module-LWE-based encryption and arithmetic.
// Documentation
use poulpy_hal::{
    api::{VecZnxNormalize, VecZnxNormalizeTmpBytes},
    layouts::{Backend, HostBackend, HostDataMut, HostDataRef, ScratchArena, ZnxView, ZnxViewMut},
};

use crate::{
    ScratchArenaTakeCore,
    layouts::{
        LWEInfos, LWEPlaintext, LWEPlaintextToBackendMut, LWEPlaintextToBackendRef, LWESecretToBackendRef, LWEToBackendRef,
        SetLWEInfos,
    },
};

/// Returns the number of scratch bytes required by [`lwe_decrypt_default`] for a
/// ciphertext described by `infos`.
///
/// The requirement is the sum of:
/// - an [`LWEPlaintext`] with `infos.size()` limbs, which holds the un-normalized
///   decryption result, and
/// - the temporary space `module` needs to normalize that plaintext into the output.
pub fn lwe_decrypt_tmp_bytes_default<M, BE: Backend, A>(module: &M, infos: &A) -> usize
where
    M: VecZnxNormalizeTmpBytes,
    A: LWEInfos,
{
    // Intermediate plaintext, sized like the ciphertext it is computed from.
    let plaintext_bytes: usize = LWEPlaintext::bytes_of(infos.size());
    // Plus whatever the final normalization step consumes on top of it.
    plaintext_bytes + module.vec_znx_normalize_tmp_bytes()
}

/// Decrypts the LWE ciphertext `res` under the secret `sk` and writes the result into `pt`.
///
/// For each limb `i` of `res`, this computes `body[i][0] + <mask[i], sk>`. The result goes
/// into coefficient 0 of limb `i` of a temporary plaintext carved from `scratch`. That
/// temporary is then normalized from `res`'s `base2k` into `pt`, using `pt`'s own `base2k`.
///
/// # Panics
///
/// - If `scratch.available()` is smaller than [`lwe_decrypt_tmp_bytes_default`] for `res`.
/// - In debug builds, if `res.n()` differs from `sk.n()`.
pub fn lwe_decrypt_default<M, BE, R, P, S>(module: &M, res: &R, pt: &mut P, sk: &S, scratch: &mut ScratchArena<'_, BE>)
where
    M: VecZnxNormalize<BE> + VecZnxNormalizeTmpBytes,
    R: LWEToBackendRef<BE> + LWEInfos,
    P: LWEPlaintextToBackendMut<BE> + SetLWEInfos + LWEInfos,
    S: LWESecretToBackendRef<BE> + LWEInfos,
    BE: Backend + HostBackend,
    for<'a> BE::BufMut<'a>: HostDataMut,
    for<'a> BE::BufRef<'a>: HostDataRef,
{
    let res = res.to_backend_ref();
    let sk = sk.to_backend_ref();

    // Mask and secret must have the same dimension for the inner product below.
    debug_assert_eq!(res.n(), sk.n());

    let required: usize = lwe_decrypt_tmp_bytes_default::<M, BE, _>(module, &res);
    assert!(
        scratch.available() >= required,
        "scratch.available(): {} < LWEDecrypt::lwe_decrypt_tmp_bytes: {}",
        scratch.available(),
        required
    );

    let arena = scratch.borrow();
    let (mut tmp, mut rest) = arena.take_lwe_plaintext_scratch(&res);

    // The secret's coefficients are the same for every limb, so fetch them once.
    let secret = sk.data.at(0, 0);
    for limb in 0..res.size() {
        let body = res.body.at(0, limb)[0];
        let mask_dot_secret: i64 = res.mask.at(0, limb).iter().zip(secret).map(|(a, s)| a * s).sum();
        tmp.data.at_mut(0, limb)[0] = body + mask_dot_secret;
    }

    // `pt` is shadowed by its mutable backend view next, so read its base2k beforehand.
    let pt_base2k = pt.base2k().into();
    let mut pt = pt.to_backend_mut();
    module.vec_znx_normalize(
        &mut pt.data,
        pt_base2k,
        0,
        0,
        &tmp.to_backend_ref().data,
        res.base2k().into(),
        0,
        &mut rest,
    );
}