poulpy-core 0.6.0

A backend-agnostic crate implementing Module-LWE-based encryption and arithmetic
Documentation
use poulpy_hal::{
    api::{
        ModuleN, ScratchArenaTakeBasic, SvpApplyDftToDftAssign, VecZnxBigAddAssign, VecZnxBigBytesOf, VecZnxBigFromSmallBackend,
        VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftApply, VecZnxDftBytesOf, VecZnxIdftApplyTmpA,
    },
    layouts::{Backend, ScratchArena, VecZnxBigToBackendMut, VecZnxBigToBackendRef, VecZnxDftToBackendMut},
};

pub use crate::api::GLWEDecrypt;
use crate::layouts::{
    GLWEBackendMut, GLWEBackendRef, GLWEInfos, GLWEToBackendMut, GLWEToBackendRef, LWEInfos, SetLWEInfos,
    prepared::{GLWESecretPreparedBackendRef, GLWESecretPreparedToBackendRef},
};

pub fn glwe_decrypt_tmp_bytes_default<M, BE: Backend, A>(module: &M, infos: &A) -> usize
where
    M: ModuleN + VecZnxDftBytesOf + VecZnxBigBytesOf + VecZnxBigNormalizeTmpBytes,
    A: GLWEInfos,
{
    let size: usize = infos.size();
    assert_eq!(module.n() as u32, infos.n());

    let lvl_0: usize = module.bytes_of_vec_znx_big(1, size);
    let lvl_1: usize = (module.bytes_of_vec_znx_dft(1, size) + module.bytes_of_vec_znx_big(1, size))
        .max(module.vec_znx_big_normalize_tmp_bytes());

    lvl_0 + lvl_1
}

pub fn glwe_decrypt_default<M, BE: Backend, R, P, S>(module: &M, res: &R, pt: &mut P, sk: &S, scratch: &mut ScratchArena<'_, BE>)
where
    M: ModuleN
        + VecZnxDftBytesOf
        + VecZnxBigBytesOf
        + VecZnxBigFromSmallBackend<BE>
        + VecZnxDftApply<BE>
        + SvpApplyDftToDftAssign<BE>
        + VecZnxIdftApplyTmpA<BE>
        + VecZnxBigAddAssign<BE>
        + VecZnxBigNormalize<BE>
        + VecZnxBigNormalizeTmpBytes,
    R: GLWEToBackendRef<BE> + GLWEInfos,
    P: GLWEToBackendMut<BE> + GLWEInfos + SetLWEInfos,
    S: GLWESecretPreparedToBackendRef<BE> + GLWEInfos,
{
    let res_backend = res.to_backend_ref();
    let mut pt_backend = pt.to_backend_mut();
    let sk_backend = sk.to_backend_ref();

    glwe_decrypt_backend_inner(module, &res_backend, &mut pt_backend, &sk_backend, scratch);
}

pub(crate) fn glwe_decrypt_backend_inner<'arena, 'scratch, M, BE: Backend>(
    module: &M,
    res: &GLWEBackendRef<'_, BE>,
    pt: &mut GLWEBackendMut<'_, BE>,
    sk: &GLWESecretPreparedBackendRef<'_, BE>,
    scratch: &'scratch mut ScratchArena<'arena, BE>,
) where
    M: ModuleN
        + VecZnxDftBytesOf
        + VecZnxBigBytesOf
        + VecZnxBigFromSmallBackend<BE>
        + VecZnxDftApply<BE>
        + SvpApplyDftToDftAssign<BE>
        + VecZnxIdftApplyTmpA<BE>
        + VecZnxBigAddAssign<BE>
        + VecZnxBigNormalize<BE>
        + VecZnxBigNormalizeTmpBytes,
{
    #[cfg(debug_assertions)]
    {
        assert_eq!(res.rank(), sk.rank());
        assert_eq!(res.n(), sk.n());
        assert_eq!(pt.n(), sk.n());
    }
    assert!(
        scratch.available() >= glwe_decrypt_tmp_bytes_default::<M, BE, _>(module, res),
        "scratch.available(): {} < GLWEDecrypt::glwe_decrypt_tmp_bytes: {}",
        scratch.available(),
        glwe_decrypt_tmp_bytes_default::<M, BE, _>(module, res)
    );

    let cols: usize = (res.rank() + 1).into();
    let (mut c0_big, mut scratch_1) = scratch.borrow().take_vec_znx_big_scratch(module, 1, res.size());
    module.vec_znx_big_from_small_backend(&mut c0_big, 0, &res.data, 0);

    for i in 1..cols {
        let (mut ci_dft, scratch_2) = scratch_1.borrow().take_vec_znx_dft_scratch(module, 1, res.size());
        module.vec_znx_dft_apply(1, 0, &mut ci_dft, 0, &res.data, i);
        {
            let mut ci_dft_backend = ci_dft.to_backend_mut();
            module.svp_apply_dft_to_dft_assign(&mut ci_dft_backend, 0, &sk.data, i - 1);
        }
        let (mut ci_big, _) = scratch_2.take_vec_znx_big_scratch(module, 1, res.size());
        {
            let mut ci_big_backend = ci_big.to_backend_mut();
            let mut ci_dft_backend = ci_dft.to_backend_mut();
            module.vec_znx_idft_apply_tmpa(&mut ci_big_backend, 0, &mut ci_dft_backend, 0);
        }
        let ci_big_ref = ci_big.to_backend_ref();
        module.vec_znx_big_add_assign(&mut c0_big, 0, &ci_big_ref, 0);
    }

    let c0_big_ref = c0_big.to_backend_ref();
    let pt_base2k = pt.base2k();
    let _ = scratch_1.apply_mut(|scratch| {
        module.vec_znx_big_normalize(
            &mut pt.data,
            pt_base2k.into(),
            0,
            0,
            &c0_big_ref,
            res.base2k().into(),
            0,
            scratch,
        )
    });
}