poulpy-ckks 0.6.0

A backend-agnostic crate implementing the CKKS FHE scheme
Documentation
use anyhow::Result;
use poulpy_core::layouts::{GLWEInfos, GLWEToBackendMut, LWEInfos};
use poulpy_hal::{
    api::{
        VecZnxLshBackend, VecZnxLshTmpBytes, VecZnxRshAddCoeffIntoBackend, VecZnxRshAddIntoBackend, VecZnxRshBackend,
        VecZnxRshSubBackend, VecZnxRshSubCoeffIntoBackend, VecZnxRshTmpBytes,
    },
    layouts::{Backend, ScratchArena},
};

use crate::GLWEToBackendRef;

use crate::{
    CKKSInfos, CKKSMeta, SetCKKSInfos, ensure_base2k_match, ensure_plaintext_alignment, ensure_plaintext_coeff_in_range,
    ensure_plaintext_degree_match,
};

/// Default, backend-generic implementations of CKKS plaintext/ciphertext operations.
///
/// Every method first validates its operands with the crate's `ensure_*` helpers and then
/// dispatches to the matching `poulpy_hal` backend primitive, passing index `0` for both the
/// destination and the source data (presumably the GLWE body column — confirm against the
/// backend API).
pub trait CKKSPlaintextDefault<BE: Backend> {
    /// Adds the plaintext vector `pt` into the ciphertext `ct` in place.
    ///
    /// `pt` is aligned to `ct` using the shift returned by `ensure_plaintext_alignment`
    /// (computed from `ct`'s log budget and `pt`'s log delta / max k), then handed to
    /// `vec_znx_rsh_add_into_backend`.
    ///
    /// # Errors
    /// Returns an error if `ct` and `pt` have different `base2k`, different ring degrees,
    /// or if `pt` cannot be aligned with `ct`'s remaining budget.
    fn ckks_add_pt_vec_into_default<Dst, A>(&self, ct: &mut Dst, pt: &A, scratch: &mut ScratchArena<'_, BE>) -> Result<()>
    where
        Self: VecZnxRshAddIntoBackend<BE>,
        Dst: GLWEToBackendMut<BE> + LWEInfos + CKKSInfos,
        A: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
    {
        const OP: &str = "ckks_add_pt_vec";
        ensure_base2k_match(OP, ct.base2k().as_usize(), pt.base2k().as_usize())?;
        ensure_plaintext_degree_match(OP, ct.n().as_usize(), pt.n().as_usize())?;
        let offset = ensure_plaintext_alignment(OP, ct.log_budget(), pt.log_delta(), pt.max_k().as_usize())?;
        let base2k = ct.base2k().as_usize();
        let mut ct_ref = GLWEToBackendMut::to_backend_mut(ct);
        let pt_ref = GLWEToBackendRef::to_backend_ref(pt);
        self.vec_znx_rsh_add_into_backend(base2k, offset, ct_ref.data_mut(), 0, pt_ref.data(), 0, scratch);
        Ok(())
    }

    /// Adds a single coefficient of `pt` (`coeff_pt`) into a single coefficient of `ct`
    /// (`coeff_ct`), in place.
    ///
    /// Unlike the vector variant, the ring degrees of `ct` and `pt` need not match. Each
    /// coefficient index only has to lie within its own operand's degree.
    ///
    /// # Errors
    /// Returns an error in any of these cases:
    /// - `ct` and `pt` have different `base2k`.
    /// - `coeff_ct` is out of range for `ct`.
    /// - `coeff_pt` is out of range for `pt`.
    /// - `pt` cannot be aligned with `ct`'s remaining budget.
    fn ckks_add_pt_const_into_default<Dst, A>(
        &self,
        ct: &mut Dst,
        coeff_ct: usize,
        pt: &A,
        coeff_pt: usize,
        scratch: &mut ScratchArena<'_, BE>,
    ) -> Result<()>
    where
        Self: VecZnxRshAddCoeffIntoBackend<BE>,
        Dst: GLWEToBackendMut<BE> + LWEInfos + CKKSInfos,
        A: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
    {
        const OP: &str = "ckks_add_pt_const";
        ensure_base2k_match(OP, ct.base2k().as_usize(), pt.base2k().as_usize())?;
        ensure_plaintext_coeff_in_range(OP, "ciphertext", coeff_ct, ct.n().as_usize())?;
        ensure_plaintext_coeff_in_range(OP, "plaintext", coeff_pt, pt.n().as_usize())?;
        let offset = ensure_plaintext_alignment(OP, ct.log_budget(), pt.log_delta(), pt.max_k().as_usize())?;
        let base2k = ct.base2k().as_usize();
        let mut ct_ref = GLWEToBackendMut::to_backend_mut(ct);
        let pt_ref = GLWEToBackendRef::to_backend_ref(pt);
        // Note the backend argument order: source coefficient first, then destination coefficient.
        self.vec_znx_rsh_add_coeff_into_backend(
            base2k,
            offset,
            ct_ref.data_mut(),
            0,
            pt_ref.data(),
            0,
            coeff_pt,
            coeff_ct,
            scratch,
        );

        Ok(())
    }

    /// Subtracts a single coefficient of `pt` (`coeff_pt`) from a single coefficient of `ct`
    /// (`coeff_ct`), in place.
    ///
    /// Validation is identical to `ckks_add_pt_const_into_default`.
    ///
    /// # Errors
    /// Returns an error in any of these cases:
    /// - `ct` and `pt` have different `base2k`.
    /// - `coeff_ct` is out of range for `ct`.
    /// - `coeff_pt` is out of range for `pt`.
    /// - `pt` cannot be aligned with `ct`'s remaining budget.
    fn ckks_sub_pt_const_into_default<Dst, A>(
        &self,
        ct: &mut Dst,
        coeff_ct: usize,
        pt: &A,
        coeff_pt: usize,
        scratch: &mut ScratchArena<'_, BE>,
    ) -> Result<()>
    where
        Self: VecZnxRshSubCoeffIntoBackend<BE>,
        Dst: GLWEToBackendMut<BE> + LWEInfos + CKKSInfos,
        A: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
    {
        const OP: &str = "ckks_sub_pt_const";
        ensure_base2k_match(OP, ct.base2k().as_usize(), pt.base2k().as_usize())?;
        ensure_plaintext_coeff_in_range(OP, "ciphertext", coeff_ct, ct.n().as_usize())?;
        ensure_plaintext_coeff_in_range(OP, "plaintext", coeff_pt, pt.n().as_usize())?;
        let offset = ensure_plaintext_alignment(OP, ct.log_budget(), pt.log_delta(), pt.max_k().as_usize())?;
        let base2k = ct.base2k().as_usize();
        let mut ct_ref = GLWEToBackendMut::to_backend_mut(ct);
        let pt_ref = GLWEToBackendRef::to_backend_ref(pt);
        // Note the backend argument order: source coefficient first, then destination coefficient.
        self.vec_znx_rsh_sub_coeff_into_backend(
            base2k,
            offset,
            ct_ref.data_mut(),
            0,
            pt_ref.data(),
            0,
            coeff_pt,
            coeff_ct,
            scratch,
        );

        Ok(())
    }

    /// Subtracts the plaintext vector `pt` from the ciphertext `ct` in place.
    ///
    /// Validation and alignment are identical to `ckks_add_pt_vec_into_default`. The
    /// aligned data is then handed to `vec_znx_rsh_sub_backend`.
    ///
    /// # Errors
    /// Returns an error if `ct` and `pt` have different `base2k`, different ring degrees,
    /// or if `pt` cannot be aligned with `ct`'s remaining budget.
    fn ckks_sub_pt_vec_into_default<Dst, A>(&self, ct: &mut Dst, pt: &A, scratch: &mut ScratchArena<'_, BE>) -> Result<()>
    where
        Self: VecZnxRshSubBackend<BE>,
        Dst: GLWEToBackendMut<BE> + LWEInfos + CKKSInfos,
        A: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
    {
        const OP: &str = "ckks_sub_pt_vec";
        ensure_base2k_match(OP, ct.base2k().as_usize(), pt.base2k().as_usize())?;
        ensure_plaintext_degree_match(OP, ct.n().as_usize(), pt.n().as_usize())?;
        let offset = ensure_plaintext_alignment(OP, ct.log_budget(), pt.log_delta(), pt.max_k().as_usize())?;
        let base2k = ct.base2k().as_usize();
        let mut ct_ref = GLWEToBackendMut::to_backend_mut(ct);
        let pt_ref = GLWEToBackendRef::to_backend_ref(pt);
        self.vec_znx_rsh_sub_backend(base2k, offset, ct_ref.data_mut(), 0, pt_ref.data(), 0, scratch);
        Ok(())
    }

    /// Scratch size, in bytes, needed by `ckks_extract_pt_default` /
    /// `ckks_extract_pt_with_meta_default`.
    ///
    /// Extraction performs either a right shift or a left shift, never both. It therefore
    /// needs the larger of the two shift scratch requirements.
    fn ckks_extract_pt_tmp_bytes_default(&self) -> usize
    where
        Self: VecZnxLshTmpBytes + VecZnxRshTmpBytes,
    {
        self.vec_znx_rsh_tmp_bytes().max(self.vec_znx_lsh_tmp_bytes())
    }

    /// Extracts `src` into the plaintext `dst`, using `src`'s own CKKS metadata.
    ///
    /// Thin wrapper over `ckks_extract_pt_with_meta_default` with `src.meta()`.
    ///
    /// # Errors
    /// Same as `ckks_extract_pt_with_meta_default`.
    fn ckks_extract_pt_default<D, S>(&self, dst: &mut D, src: &S, scratch: &mut ScratchArena<'_, BE>) -> Result<()>
    where
        D: GLWEToBackendMut<BE> + LWEInfos + CKKSInfos + SetCKKSInfos,
        S: GLWEToBackendRef<BE> + GLWEInfos + LWEInfos + CKKSInfos,
        Self: VecZnxLshBackend<BE> + VecZnxRshBackend<BE>,
    {
        self.ckks_extract_pt_with_meta_default(dst, src, src.meta(), scratch)
    }

    /// Extracts `src` into the plaintext `dst`, interpreting `src` with the explicit
    /// metadata `src_meta`.
    ///
    /// Let `available = src_meta.log_budget() + dst.log_delta()` and `dst_k = dst.max_k()`.
    /// - If `available <= dst_k`, `src` is right-shifted by `dst_k - available` into `dst`.
    ///   A shift of 0 covers the already-aligned case.
    /// - Otherwise `src` is left-shifted by `available - dst_k` into `dst`.
    ///
    /// This method does not call any `SetCKKSInfos` setter on `dst`.
    /// NOTE(review): confirm that callers are responsible for `dst`'s metadata.
    ///
    /// # Errors
    /// Returns an error if `src` and `dst` have different `base2k`. Returns
    /// `CKKSCompositionError::PlaintextAlignmentImpossible` if `available` is smaller than
    /// `dst.effective_k()`.
    fn ckks_extract_pt_with_meta_default<D, S>(
        &self,
        dst: &mut D,
        src: &S,
        src_meta: CKKSMeta,
        scratch: &mut ScratchArena<'_, BE>,
    ) -> Result<()>
    where
        D: GLWEToBackendMut<BE> + LWEInfos + CKKSInfos + SetCKKSInfos,
        S: GLWEToBackendRef<BE> + GLWEInfos + LWEInfos,
        Self: VecZnxLshBackend<BE> + VecZnxRshBackend<BE>,
    {
        const OP: &str = "ckks_extract_pt";
        ensure_base2k_match(OP, src.base2k().as_usize(), dst.base2k().as_usize())?;

        // Sum of `src`'s remaining log budget and `dst`'s log delta. It is compared against
        // `dst`'s k values to validate the extraction and choose the shift direction.
        let available = src_meta.log_budget() + dst.log_delta();
        if available < dst.effective_k() {
            return Err(crate::CKKSCompositionError::PlaintextAlignmentImpossible {
                op: OP,
                ct_log_budget: src_meta.log_budget(),
                pt_log_delta: dst.log_delta(),
                pt_k: dst.effective_k(),
            }
            .into());
        }

        let dst_k = dst.max_k().as_usize();
        let dst_base2k = dst.base2k().as_usize();
        let mut dst_ref = GLWEToBackendMut::to_backend_mut(dst);
        let src_ref = GLWEToBackendRef::to_backend_ref(src);

        if available <= dst_k {
            // Also handles `available == dst_k`: a right shift by 0 is a plain copy.
            self.vec_znx_rsh_backend(
                dst_base2k,
                dst_k - available,
                dst_ref.data_mut(),
                0,
                src_ref.data(),
                0,
                scratch,
            );
        } else {
            self.vec_znx_lsh_backend(
                dst_base2k,
                available - dst_k,
                dst_ref.data_mut(),
                0,
                src_ref.data(),
                0,
                scratch,
            );
        }
        Ok(())
    }
}