poulpy-ckks 0.6.0

A backend-agnostic crate implementing the CKKS FHE scheme
Documentation
use anyhow::Result;
use poulpy_core::{
    GLWEAdd, GLWENormalize, GLWEShift,
    layouts::{Base2K, GLWEPlaintext, GLWEToBackendMut, GLWEToBackendRef, LWEInfos},
};
use poulpy_hal::{
    api::{VecZnxRshAddCoeffIntoBackend, VecZnxRshAddIntoBackend, VecZnxRshTmpBytes},
    layouts::{Backend, ScratchArena, VecZnx},
};

use crate::{
    CKKSInfos, CKKSMeta, SetCKKSInfos, checked_log_budget_sub, ckks_offset_binary, ckks_offset_unary,
    default::CKKSPlaintextDefault,
    layouts::{CKKSPlaintext, CKKSPlaintextVecHostCodec},
};

/// Builds a plaintext holding the encoded value `1.0` for backend `BE`.
///
/// The value is first encoded into a host-side plaintext allocated with
/// metadata `log_delta = 1`, `log_budget = 0`; the resulting bytes are then
/// handed to `BE::from_host_bytes` and re-wrapped with the same shape,
/// `base2k` and metadata.
///
/// # Errors
///
/// Propagates any error returned by `encode_host_floats`.
pub(crate) fn ckks_one_pt<BE>(base2k: Base2K) -> Result<CKKSPlaintext<BE::OwnedBuf>>
where
    BE: Backend,
{
    let meta = CKKSMeta {
        log_delta: 1,
        log_budget: 0,
    };

    // Host-side staging plaintext: allocate, then encode the single float 1.0.
    let host_inner = GLWEPlaintext::alloc_with_meta(1usize.into(), base2k, meta.min_k(base2k));
    let mut host = CKKSPlaintext::from_inner(host_inner, meta);
    host.encode_host_floats(&[1.0f64])?;

    // Mirror the host layout exactly when rebuilding on the backend buffer.
    let host_vec = &host.inner.data;
    let shape = host_vec.shape();
    let data = VecZnx::from_data_with_max_size(
        BE::from_host_bytes(host_vec.data.as_ref()),
        shape.n(),
        shape.cols(),
        shape.size(),
        shape.max_size(),
    );

    Ok(CKKSPlaintext::from_inner(GLWEPlaintext { data, base2k }, meta))
}

/// Provided (default) implementations of the CKKS addition family for a
/// backend `BE`.
///
/// Every method has a body and declares the primitives it relies on through
/// its own `where Self: …` clause, so the trait has no supertraits.
///
/// Most operations come in two flavours: the `*_unsafe_default` method performs
/// the arithmetic and the metadata update only, while its sibling without
/// `unsafe` in the name runs the same code and then calls
/// `glwe_normalize_assign` on `dst`. "unsafe" is a naming convention here, not
/// the Rust `unsafe` keyword.
pub trait CKKSAddDefault<BE: Backend> {
    /// Returns the larger of `glwe_shift_tmp_bytes()` and
    /// `glwe_normalize_tmp_bytes()`.
    ///
    /// Presumably the scratch requirement of the ciphertext/ciphertext add
    /// methods — confirm against the callers.
    fn ckks_add_tmp_bytes_default(&self) -> usize
    where
        Self: GLWEShift<BE> + GLWENormalize<BE>,
    {
        self.glwe_shift_tmp_bytes().max(self.glwe_normalize_tmp_bytes())
    }

    /// Returns the maximum of `glwe_shift_tmp_bytes()`,
    /// `vec_znx_rsh_tmp_bytes()` and `glwe_normalize_tmp_bytes()`.
    ///
    /// Presumably the scratch requirement of the plaintext-vector add
    /// methods — confirm against the callers.
    fn ckks_add_pt_vec_tmp_bytes_default(&self) -> usize
    where
        Self: GLWEShift<BE> + GLWENormalize<BE> + VecZnxRshTmpBytes,
    {
        self.glwe_shift_tmp_bytes()
            .max(self.vec_znx_rsh_tmp_bytes())
            .max(self.glwe_normalize_tmp_bytes())
    }

    /// Returns the maximum of `glwe_shift_tmp_bytes()`,
    /// `glwe_normalize_tmp_bytes()` and `vec_znx_rsh_tmp_bytes()`.
    ///
    /// Presumably the scratch requirement of the plaintext-constant add
    /// methods — confirm against the callers.
    fn ckks_add_pt_const_tmp_bytes_default(&self) -> usize
    where
        Self: GLWEShift<BE> + GLWENormalize<BE> + VecZnxRshTmpBytes,
    {
        self.glwe_shift_tmp_bytes()
            .max(self.glwe_normalize_tmp_bytes())
            .max(self.vec_znx_rsh_tmp_bytes())
    }

    /// Writes `a + b` into `dst` and then normalizes `dst`.
    ///
    /// Equivalent to [`Self::ckks_add_into_unsafe_default`] followed by
    /// `glwe_normalize_assign(dst, scratch)`.
    ///
    /// # Errors
    ///
    /// Propagates the error of the un-normalized variant; normalization is
    /// skipped in that case.
    fn ckks_add_into_default<Dst, A, B>(&self, dst: &mut Dst, a: &A, b: &B, scratch: &mut ScratchArena<'_, BE>) -> Result<()>
    where
        Self: GLWEAdd<BE> + GLWEShift<BE> + GLWENormalize<BE>,
        Dst: GLWEToBackendMut<BE> + LWEInfos + SetCKKSInfos + CKKSInfos,
        A: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
        B: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
    {
        self.ckks_add_into_unsafe_default(dst, a, b, scratch)?;
        self.glwe_normalize_assign(dst, scratch);
        Ok(())
    }

    /// Writes `a + b` into `dst` without normalizing the result.
    ///
    /// The shift `offset` comes from `ckks_offset_binary(dst, a, b)`. When the
    /// offset is zero and both operands report the same `log_budget`, the
    /// operands are added directly. Otherwise the operand with the smaller
    /// `log_budget` is written to `dst` shifted by `offset`, and the other one
    /// is accumulated shifted by `offset` plus the budget difference.
    ///
    /// Afterwards `dst` gets `log_delta = min(a.log_delta, b.log_delta)` and
    /// `log_budget = min(a.log_budget, b.log_budget) - offset`.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `checked_log_budget_sub` when the new
    /// budget cannot be computed. The check runs before any write, so `dst` is
    /// left untouched on failure.
    fn ckks_add_into_unsafe_default<Dst, A, B>(
        &self,
        dst: &mut Dst,
        a: &A,
        b: &B,
        scratch: &mut ScratchArena<'_, BE>,
    ) -> Result<()>
    where
        Self: GLWEAdd<BE> + GLWEShift<BE>,
        Dst: GLWEToBackendMut<BE> + LWEInfos + SetCKKSInfos + CKKSInfos,
        A: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
        B: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
    {
        let offset = ckks_offset_binary(dst, a, b);
        let a_budget = a.log_budget();
        let b_budget = b.log_budget();

        // Validate the resulting budget up front: it depends only on the
        // inputs, and failing after the shifts would leave `dst` half-written.
        let log_budget = checked_log_budget_sub("add", a_budget.min(b_budget), offset)?;

        if offset == 0 && a_budget == b_budget {
            self.glwe_add_into(dst, a, b);
        } else if a_budget <= b_budget {
            self.glwe_lsh(dst, a, offset, scratch);
            self.glwe_lsh_add(dst, b, b_budget - a_budget + offset, scratch);
        } else {
            self.glwe_lsh(dst, b, offset, scratch);
            self.glwe_lsh_add(dst, a, a_budget - b_budget + offset, scratch);
        }

        dst.set_log_delta(a.log_delta().min(b.log_delta()));
        dst.set_log_budget(log_budget);
        Ok(())
    }

    /// Adds `a` into `dst` in place and then normalizes `dst`.
    ///
    /// Equivalent to [`Self::ckks_add_assign_unsafe_default`] followed by
    /// `glwe_normalize_assign(dst, scratch)`.
    fn ckks_add_assign_default<Dst, A>(&self, dst: &mut Dst, a: &A, scratch: &mut ScratchArena<'_, BE>) -> Result<()>
    where
        Self: GLWEAdd<BE> + GLWEShift<BE> + GLWENormalize<BE>,
        Dst: GLWEToBackendMut<BE> + CKKSInfos + SetCKKSInfos,
        A: GLWEToBackendRef<BE> + CKKSInfos,
    {
        self.ckks_add_assign_unsafe_default(dst, a, scratch)?;
        self.glwe_normalize_assign(dst, scratch);
        Ok(())
    }

    /// Adds `a` into `dst` in place without normalizing the result.
    ///
    /// The operand with the larger `log_budget` is shifted by the budget
    /// difference before the addition: `a` via `glwe_lsh_add`, or `dst` itself
    /// via `glwe_lsh_assign`. With equal budgets the operands are added as-is.
    ///
    /// Afterwards `dst` carries the minimum of both `log_budget` values and
    /// the minimum of both `log_delta` values. This method has no fallible
    /// step; it always returns `Ok(())`.
    fn ckks_add_assign_unsafe_default<Dst, A>(&self, dst: &mut Dst, a: &A, scratch: &mut ScratchArena<'_, BE>) -> Result<()>
    where
        Self: GLWEAdd<BE> + GLWEShift<BE>,
        Dst: GLWEToBackendMut<BE> + CKKSInfos + SetCKKSInfos,
        A: GLWEToBackendRef<BE> + CKKSInfos,
    {
        // Snapshot the destination budget: `dst` is mutated below.
        let dst_log_budget = dst.log_budget();

        if dst_log_budget < a.log_budget() {
            self.glwe_lsh_add(dst, a, a.log_budget() - dst_log_budget, scratch);
        } else if dst_log_budget > a.log_budget() {
            self.glwe_lsh_assign(dst, dst_log_budget - a.log_budget(), scratch);
            self.glwe_add_assign(dst, a);
        } else {
            self.glwe_add_assign(dst, a);
        }

        dst.set_log_budget(dst_log_budget.min(a.log_budget()));
        dst.set_log_delta(dst.log_delta().min(a.log_delta()));
        Ok(())
    }

    /// Adds the constant produced by `ckks_one_pt` to coefficient `0` of `dst`
    /// and normalizes `dst`.
    ///
    /// A fresh "one" plaintext is built on every call using `dst.base2k()`,
    /// then forwarded to [`Self::ckks_add_pt_const_assign_default`] with both
    /// coefficient indices set to `0`.
    ///
    /// # Errors
    ///
    /// Propagates errors from `ckks_one_pt` and from the constant addition.
    fn ckks_add_one_assign_default<Dst>(&self, dst: &mut Dst, scratch: &mut ScratchArena<'_, BE>) -> Result<()>
    where
        Self: GLWENormalize<BE> + VecZnxRshAddCoeffIntoBackend<BE> + CKKSPlaintextDefault<BE>,
        Dst: GLWEToBackendMut<BE> + LWEInfos + CKKSInfos,
    {
        let one = ckks_one_pt::<BE>(dst.base2k())?;
        self.ckks_add_pt_const_assign_default(dst, 0, &one, 0, scratch)
    }

    /// Writes `a + pt` into `dst` and then normalizes `dst`.
    ///
    /// Equivalent to [`Self::ckks_add_pt_vec_into_unsafe_default`] followed by
    /// `glwe_normalize_assign(dst, scratch)`.
    fn ckks_add_pt_vec_into_default<Dst, A, P>(
        &self,
        dst: &mut Dst,
        a: &A,
        pt: &P,
        scratch: &mut ScratchArena<'_, BE>,
    ) -> Result<()>
    where
        Self: VecZnxRshAddIntoBackend<BE> + GLWEShift<BE> + GLWENormalize<BE> + CKKSPlaintextDefault<BE>,
        Dst: GLWEToBackendMut<BE> + LWEInfos + CKKSInfos + SetCKKSInfos,
        A: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
        P: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
    {
        self.ckks_add_pt_vec_into_unsafe_default(dst, a, pt, scratch)?;
        self.glwe_normalize_assign(dst, scratch);
        Ok(())
    }

    /// Writes `a + pt` into `dst` without normalizing the result.
    ///
    /// `a` is first copied into `dst` through `glwe_lsh` with the shift given
    /// by `ckks_offset_unary(dst, a)`; `dst` then receives `a`'s metadata with
    /// `log_budget` reduced by that shift. Finally the plaintext is added in
    /// place via [`Self::ckks_add_pt_vec_assign_unsafe_default`].
    ///
    /// # Errors
    ///
    /// Returns the error from `checked_log_budget_sub` (checked before `dst`
    /// is written) or from the in-place plaintext addition.
    fn ckks_add_pt_vec_into_unsafe_default<Dst, A, P>(
        &self,
        dst: &mut Dst,
        a: &A,
        pt: &P,
        scratch: &mut ScratchArena<'_, BE>,
    ) -> Result<()>
    where
        Self: VecZnxRshAddIntoBackend<BE> + GLWEShift<BE> + CKKSPlaintextDefault<BE>,
        Dst: GLWEToBackendMut<BE> + LWEInfos + CKKSInfos + SetCKKSInfos,
        A: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
        P: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
    {
        let offset = ckks_offset_unary(dst, a);

        // Fail before mutating `dst`; otherwise an error would leave it with
        // shifted data and metadata that does not account for the shift.
        let log_budget = checked_log_budget_sub("add_pt_vec", a.log_budget(), offset)?;

        self.glwe_lsh(dst, a, offset, scratch);
        dst.set_meta(a.meta());
        dst.set_log_budget(log_budget);
        self.ckks_add_pt_vec_assign_unsafe_default(dst, pt, scratch)
    }

    /// Adds `pt` into `dst` in place and then normalizes `dst`.
    ///
    /// Equivalent to [`Self::ckks_add_pt_vec_assign_unsafe_default`] followed
    /// by `glwe_normalize_assign(dst, scratch)`.
    fn ckks_add_pt_vec_assign_default<Dst, P>(&self, dst: &mut Dst, pt: &P, scratch: &mut ScratchArena<'_, BE>) -> Result<()>
    where
        Self: VecZnxRshAddIntoBackend<BE> + GLWENormalize<BE> + CKKSPlaintextDefault<BE>,
        Dst: GLWEToBackendMut<BE> + LWEInfos + CKKSInfos,
        P: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
    {
        self.ckks_add_pt_vec_assign_unsafe_default(dst, pt, scratch)?;
        self.glwe_normalize_assign(dst, scratch);
        Ok(())
    }

    /// Adds `pt` into `dst` in place without normalizing the result.
    ///
    /// Thin forwarder to `CKKSPlaintextDefault::ckks_add_pt_vec_into_default`;
    /// any value that call returns on success is discarded.
    fn ckks_add_pt_vec_assign_unsafe_default<Dst, P>(
        &self,
        dst: &mut Dst,
        pt: &P,
        scratch: &mut ScratchArena<'_, BE>,
    ) -> Result<()>
    where
        Self: VecZnxRshAddIntoBackend<BE> + CKKSPlaintextDefault<BE>,
        Dst: GLWEToBackendMut<BE> + LWEInfos + CKKSInfos,
        P: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
    {
        CKKSPlaintextDefault::ckks_add_pt_vec_into_default(self, dst, pt, scratch)?;
        Ok(())
    }

    /// Writes `a` plus one coefficient of `cst` into `dst`, then normalizes
    /// `dst`.
    ///
    /// Equivalent to [`Self::ckks_add_pt_const_into_unsafe_default`] followed
    /// by `glwe_normalize_assign(dst, scratch)`. `dst_coeff` and `const_coeff`
    /// are forwarded unchanged.
    fn ckks_add_pt_const_into_default<Dst, A, P>(
        &self,
        dst: &mut Dst,
        a: &A,
        dst_coeff: usize,
        cst: &P,
        const_coeff: usize,
        scratch: &mut ScratchArena<'_, BE>,
    ) -> Result<()>
    where
        Self: GLWEShift<BE> + GLWENormalize<BE> + VecZnxRshAddCoeffIntoBackend<BE> + CKKSPlaintextDefault<BE>,
        Dst: GLWEToBackendMut<BE> + LWEInfos + CKKSInfos + SetCKKSInfos,
        A: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
        P: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
    {
        self.ckks_add_pt_const_into_unsafe_default(dst, a, dst_coeff, cst, const_coeff, scratch)?;
        self.glwe_normalize_assign(dst, scratch);
        Ok(())
    }

    /// Writes `a` plus one coefficient of `cst` into `dst` without normalizing
    /// the result.
    ///
    /// `a` is first copied into `dst` through `glwe_lsh` with the shift given
    /// by `ckks_offset_unary(dst, a)`; `dst` then receives `a`'s metadata with
    /// `log_budget` reduced by that shift. The constant is then added in place
    /// via [`Self::ckks_add_pt_const_assign_unsafe_default`].
    ///
    /// # Errors
    ///
    /// Returns the error from `checked_log_budget_sub` (checked before `dst`
    /// is written) or from the in-place constant addition.
    fn ckks_add_pt_const_into_unsafe_default<Dst, A, P>(
        &self,
        dst: &mut Dst,
        a: &A,
        dst_coeff: usize,
        cst: &P,
        const_coeff: usize,
        scratch: &mut ScratchArena<'_, BE>,
    ) -> Result<()>
    where
        Self: GLWEShift<BE> + VecZnxRshAddCoeffIntoBackend<BE> + CKKSPlaintextDefault<BE>,
        Dst: GLWEToBackendMut<BE> + LWEInfos + CKKSInfos + SetCKKSInfos,
        A: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
        P: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
    {
        let offset = ckks_offset_unary(dst, a);

        // Fail before mutating `dst`; otherwise an error would leave it with
        // shifted data and metadata that does not account for the shift.
        let log_budget = checked_log_budget_sub("add_const", a.log_budget(), offset)?;

        self.glwe_lsh(dst, a, offset, scratch);
        dst.set_meta(a.meta());
        dst.set_log_budget(log_budget);
        self.ckks_add_pt_const_assign_unsafe_default(dst, dst_coeff, cst, const_coeff, scratch)
    }

    /// Adds one coefficient of `cst` into `dst` in place, then normalizes
    /// `dst`.
    ///
    /// Equivalent to [`Self::ckks_add_pt_const_assign_unsafe_default`]
    /// followed by `glwe_normalize_assign(dst, scratch)`.
    fn ckks_add_pt_const_assign_default<Dst, P>(
        &self,
        dst: &mut Dst,
        dst_coeff: usize,
        cst: &P,
        const_coeff: usize,
        scratch: &mut ScratchArena<'_, BE>,
    ) -> Result<()>
    where
        Self: GLWENormalize<BE> + VecZnxRshAddCoeffIntoBackend<BE> + CKKSPlaintextDefault<BE>,
        Dst: GLWEToBackendMut<BE> + LWEInfos + CKKSInfos,
        P: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
    {
        self.ckks_add_pt_const_assign_unsafe_default(dst, dst_coeff, cst, const_coeff, scratch)?;
        self.glwe_normalize_assign(dst, scratch);
        Ok(())
    }

    /// Adds one coefficient of `cst` into `dst` in place without normalizing
    /// the result.
    ///
    /// Thin forwarder to
    /// `CKKSPlaintextDefault::ckks_add_pt_const_into_default`; any value that
    /// call returns on success is discarded.
    fn ckks_add_pt_const_assign_unsafe_default<Dst, P>(
        &self,
        dst: &mut Dst,
        dst_coeff: usize,
        cst: &P,
        const_coeff: usize,
        scratch: &mut ScratchArena<'_, BE>,
    ) -> Result<()>
    where
        Self: VecZnxRshAddCoeffIntoBackend<BE> + CKKSPlaintextDefault<BE>,
        Dst: GLWEToBackendMut<BE> + LWEInfos + CKKSInfos,
        P: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
    {
        CKKSPlaintextDefault::ckks_add_pt_const_into_default(self, dst, dst_coeff, cst, const_coeff, scratch)?;
        Ok(())
    }
}