use std::{
fmt,
marker::PhantomData,
ops::{Deref, DerefMut},
};
use anyhow::Result;
use poulpy_core::layouts::{Base2K, Degree, GLWE, GLWEInfos, GLWEToBackendMut, GLWEToBackendRef, GLWEViewMut, LWEInfos, Rank};
use poulpy_core::{GLWENormalize, ScratchArenaTakeCore};
use poulpy_hal::layouts::{Backend, Data, HostBackend, HostDataRef, Module, ScratchArena};
use crate::{CKKSInfos, CKKSMeta, SetCKKSInfos, error::CKKSCompositionError, layouts::CKKSModuleAlloc};
/// Private module holding the supertrait required by `CKKSNormalizationState`.
///
/// The module itself is not `pub`, so `Sealed` cannot be named from outside the
/// enclosing module tree.
mod sealed {
/// Supertrait of `CKKSNormalizationState`; in this file it is implemented only for
/// `Normalized` and `Unnormalized`.
pub trait Sealed {}
}
/// Type-level marker for the normalized state; this is the default state parameter
/// of `CKKSCiphertext`.
pub struct Normalized;
/// Type-level marker for the unnormalized state.
pub struct Unnormalized;
impl sealed::Sealed for Normalized {}
impl sealed::Sealed for Unnormalized {}
/// Marker trait bounding the state parameter `S` of `CKKSCiphertext`.
///
/// It requires the private `sealed::Sealed` supertrait; the only implementors in this
/// file are `Normalized` and `Unnormalized`.
pub trait CKKSNormalizationState: sealed::Sealed {}
impl CKKSNormalizationState for Normalized {}
impl CKKSNormalizationState for Unnormalized {}
/// A GLWE ciphertext paired with CKKS metadata and a type-level normalization state.
///
/// `S` defaults to `Normalized`. It is carried only through `PhantomData`, so it adds
/// no runtime data.
pub struct CKKSCiphertext<D: Data, S: CKKSNormalizationState = Normalized> {
/// Wrapped GLWE ciphertext; also reachable through `Deref`/`DerefMut`.
pub(crate) inner: GLWE<D>,
/// CKKS metadata returned by `CKKSInfos::meta`.
pub(crate) meta: CKKSMeta,
/// Zero-sized tag recording the state parameter `S`.
_state: PhantomData<S>,
}
impl<D: Data, S: CKKSNormalizationState> CKKSCiphertext<D, S> {
    /// Builds a ciphertext from an existing GLWE and its CKKS metadata.
    ///
    /// No validation is performed; the state parameter `S` is chosen by the caller.
    pub(crate) fn from_inner(inner: GLWE<D>, meta: CKKSMeta) -> Self {
        let _state = PhantomData;
        Self { inner, meta, _state }
    }

    /// Returns a `Vec<u8>`-backed ciphertext with the same metadata and state.
    ///
    /// The wrapped GLWE is converted with its own `to_host_owned::<BE>()`.
    pub fn to_host_owned<BE>(&self) -> CKKSCiphertext<Vec<u8>, S>
    where
        BE: Backend<OwnedBuf = D>,
    {
        let host_glwe = self.inner.to_host_owned::<BE>();
        CKKSCiphertext::<Vec<u8>, S>::from_inner(host_glwe, self.meta)
    }

    /// Formats the ciphertext by first producing a host-owned copy and then using its
    /// `Display` implementation.
    pub fn display_host<BE>(&self) -> String
    where
        BE: Backend<OwnedBuf = D>,
    {
        let host_copy = self.to_host_owned::<BE>();
        host_copy.to_string()
    }

    /// Borrows the wrapped GLWE as a backend reference view.
    pub fn to_ref<BE: Backend>(&self) -> GLWE<BE::BufRef<'_>>
    where
        GLWE<D>: GLWEToBackendRef<BE>,
    {
        <GLWE<D> as GLWEToBackendRef<BE>>::to_backend_ref(&self.inner)
    }

    /// Borrows the wrapped GLWE as a mutable backend view.
    pub fn to_mut<BE: Backend>(&mut self) -> GLWE<BE::BufMut<'_>>
    where
        GLWE<D>: GLWEToBackendMut<BE>,
    {
        <GLWE<D> as GLWEToBackendMut<BE>>::to_backend_mut(&mut self.inner)
    }

    /// Replaces the metadata after checking it against this ciphertext's `max_k`.
    ///
    /// # Errors
    /// Returns `CKKSCompositionError::LimbReallocationShrinksBelowMetadata` when
    /// `meta.effective_k()` is greater than `self.max_k()`; the stored metadata is
    /// left untouched in that case.
    pub fn set_meta_checked(&mut self, meta: CKKSMeta) -> Result<()> {
        let max_k = self.max_k().as_usize();
        if meta.effective_k() > max_k {
            anyhow::bail!(CKKSCompositionError::LimbReallocationShrinksBelowMetadata {
                max_k,
                log_delta: meta.log_delta(),
                base2k: self.base2k().as_usize(),
                requested_limbs: self.size(),
            });
        }
        self.meta = meta;
        Ok(())
    }
}
/// Dereferences to the wrapped `GLWE<D>`, so its methods can be called on the
/// ciphertext directly.
impl<D: Data, S: CKKSNormalizationState> Deref for CKKSCiphertext<D, S> {
type Target = GLWE<D>;
fn deref(&self) -> &Self::Target {
&self.inner
}
}
/// Mutable counterpart of the `Deref` impl: yields the wrapped `GLWE<D>`.
/// The metadata is not reachable through this path.
impl<D: Data, S: CKKSNormalizationState> DerefMut for CKKSCiphertext<D, S> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.inner
}
}
/// Layout queries are answered by the wrapped GLWE.
impl<D: Data, S: CKKSNormalizationState> LWEInfos for CKKSCiphertext<D, S> {
/// Forwards to `inner.base2k()`.
fn base2k(&self) -> Base2K {
self.inner.base2k()
}
/// Forwards to `inner.n()`.
fn n(&self) -> Degree {
self.inner.n()
}
/// Forwards to `inner.size()`.
fn size(&self) -> usize {
self.inner.size()
}
}
/// The rank is the rank of the wrapped GLWE.
impl<D: Data, S: CKKSNormalizationState> GLWEInfos for CKKSCiphertext<D, S> {
/// Forwards to `inner.rank()`.
fn rank(&self) -> Rank {
self.inner.rank()
}
}
/// CKKS metadata accessors backed by the stored `meta` field.
impl<D: Data, S: CKKSNormalizationState> CKKSInfos for CKKSCiphertext<D, S> {
/// Returns a copy of the stored metadata.
fn meta(&self) -> CKKSMeta {
self.meta
}
/// Forwards to `meta.log_delta()`.
fn log_delta(&self) -> usize {
self.meta.log_delta()
}
/// Forwards to `meta.log_budget()`.
fn log_budget(&self) -> usize {
self.meta.log_budget()
}
}
/// Unchecked metadata setter.
impl<D: Data, S: CKKSNormalizationState> SetCKKSInfos for CKKSCiphertext<D, S> {
/// Overwrites the stored metadata without validating it against the ciphertext;
/// `set_meta_checked` is the validating variant.
fn set_meta(&mut self, meta: CKKSMeta) {
self.meta = meta;
}
}
/// Renders the wrapped GLWE using its own `Display` implementation; the CKKS metadata
/// is not part of the output.
impl<D: HostDataRef, S: CKKSNormalizationState> fmt::Display for CKKSCiphertext<D, S> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_fmt(format_args!("{}", self.inner))
    }
}
/// Lets a CKKS ciphertext be used wherever a backend GLWE reference view is accepted,
/// by delegating to the wrapped GLWE.
impl<BE: Backend, D: Data, S: CKKSNormalizationState> GLWEToBackendRef<BE> for CKKSCiphertext<D, S>
where
    GLWE<D>: GLWEToBackendRef<BE>,
{
    fn to_backend_ref(&self) -> GLWE<BE::BufRef<'_>> {
        <GLWE<D> as GLWEToBackendRef<BE>>::to_backend_ref(&self.inner)
    }
}
/// Lets a CKKS ciphertext be used wherever a mutable backend GLWE view is accepted,
/// by delegating to the wrapped GLWE.
impl<BE: Backend, D: Data, S: CKKSNormalizationState> GLWEToBackendMut<BE> for CKKSCiphertext<D, S>
where
    GLWE<D>: GLWEToBackendMut<BE>,
{
    fn to_backend_mut(&mut self) -> GLWE<BE::BufMut<'_>> {
        <GLWE<D> as GLWEToBackendMut<BE>>::to_backend_mut(&mut self.inner)
    }
}
/// A mutable GLWE view paired with CKKS metadata.
///
/// In this file it is produced by `ScratchArenaTakeCKKS::take_ckks_ciphertext_scratch`,
/// which wraps the view returned by `take_glwe_scratch`.
pub struct CKKSCiphertextViewMut<'a, BE: Backend + 'a> {
/// Wrapped mutable GLWE view; also reachable through `Deref`/`DerefMut`.
inner: GLWEViewMut<'a, BE>,
/// CKKS metadata attached to the view.
meta: CKKSMeta,
}
impl<'a, BE: Backend + 'a> CKKSCiphertextViewMut<'a, BE> {
/// Pairs an existing mutable GLWE view with CKKS metadata; performs no validation.
pub(crate) fn from_inner(inner: GLWEViewMut<'a, BE>, meta: CKKSMeta) -> Self {
Self { inner, meta }
}
}
/// Dereferences to the wrapped `GLWEViewMut`.
impl<'a, BE: Backend + 'a> Deref for CKKSCiphertextViewMut<'a, BE> {
type Target = GLWEViewMut<'a, BE>;
fn deref(&self) -> &Self::Target {
&self.inner
}
}
/// Mutable counterpart of the `Deref` impl: yields the wrapped `GLWEViewMut`.
impl<'a, BE: Backend + 'a> DerefMut for CKKSCiphertextViewMut<'a, BE> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.inner
}
}
/// Layout queries are answered by the wrapped GLWE view.
impl<'a, BE: Backend + 'a> LWEInfos for CKKSCiphertextViewMut<'a, BE> {
/// Forwards to `inner.base2k()`.
fn base2k(&self) -> Base2K {
self.inner.base2k()
}
/// Forwards to `inner.n()`.
fn n(&self) -> Degree {
self.inner.n()
}
/// Forwards to `inner.size()`.
fn size(&self) -> usize {
self.inner.size()
}
}
/// The rank is the rank of the wrapped GLWE view.
impl<'a, BE: Backend + 'a> GLWEInfos for CKKSCiphertextViewMut<'a, BE> {
/// Forwards to `inner.rank()`.
fn rank(&self) -> Rank {
self.inner.rank()
}
}
/// CKKS metadata accessors backed by the stored `meta` field.
impl<'a, BE: Backend + 'a> CKKSInfos for CKKSCiphertextViewMut<'a, BE> {
/// Returns a copy of the stored metadata.
fn meta(&self) -> CKKSMeta {
self.meta
}
/// Forwards to `meta.log_delta()`.
fn log_delta(&self) -> usize {
self.meta.log_delta()
}
/// Forwards to `meta.log_budget()`.
fn log_budget(&self) -> usize {
self.meta.log_budget()
}
}
/// Unchecked metadata setter for the view.
impl<'a, BE: Backend + 'a> SetCKKSInfos for CKKSCiphertextViewMut<'a, BE> {
/// Overwrites the stored metadata without validating it against the view.
fn set_meta(&mut self, meta: CKKSMeta) {
self.meta = meta;
}
}
/// Lets the CKKS view be used wherever a backend GLWE reference view is accepted.
impl<'a, BE: Backend + 'a> GLWEToBackendRef<BE> for CKKSCiphertextViewMut<'a, BE> {
/// Delegates to the wrapped view's `to_backend_ref`.
fn to_backend_ref(&self) -> GLWE<BE::BufRef<'_>> {
self.inner.to_backend_ref()
}
}
/// Lets the CKKS view be used wherever a mutable backend GLWE view is accepted.
impl<'a, BE: Backend + 'a> GLWEToBackendMut<BE> for CKKSCiphertextViewMut<'a, BE> {
/// Delegates to the wrapped view's `to_backend_mut`.
fn to_backend_mut(&mut self) -> GLWE<BE::BufMut<'_>> {
self.inner.to_backend_mut()
}
}
/// Extension methods that obtain CKKS ciphertext temporaries from a scratch provider.
///
/// Every method consumes `self`, calls `take_glwe_scratch`, and returns the wrapped
/// ciphertext together with the `Self` value that `take_glwe_scratch` handed back.
pub trait ScratchArenaTakeCKKS<'a, BE: Backend>: ScratchArenaTakeCore<'a, BE> + Sized {
    /// Takes a GLWE view described by `infos` and attaches `meta` to it.
    fn take_ckks_ciphertext_scratch<I>(self, infos: &I, meta: CKKSMeta) -> (CKKSCiphertextViewMut<'a, BE>, Self)
    where
        BE: 'a,
        I: GLWEInfos,
    {
        let (glwe_view, rest) = self.take_glwe_scratch(infos);
        let ciphertext = CKKSCiphertextViewMut::from_inner(glwe_view, meta);
        (ciphertext, rest)
    }

    /// Same as [`Self::take_ckks_ciphertext_scratch`], using `ct` both as the layout
    /// description and as the source of the metadata.
    fn take_ckks_ciphertext_like_scratch<C>(self, ct: &C) -> (CKKSCiphertextViewMut<'a, BE>, Self)
    where
        BE: 'a,
        C: GLWEInfos + CKKSInfos,
    {
        let meta = ct.meta();
        self.take_ckks_ciphertext_scratch(ct, meta)
    }

    /// Takes a GLWE view described by `infos`, unwraps it with `into_inner`, and
    /// returns it as an [`UnnormalizedCKKSCiphertext`] carrying `meta`.
    fn take_unnormalized_ckks_ciphertext_scratch<I>(
        self,
        infos: &I,
        meta: CKKSMeta,
    ) -> (UnnormalizedCKKSCiphertext<BE::BufMut<'a>>, Self)
    where
        BE: 'a,
        I: GLWEInfos,
    {
        let (glwe_view, rest) = self.take_glwe_scratch(infos);
        let glwe = glwe_view.into_inner();
        let ciphertext = UnnormalizedCKKSCiphertext::from_inner(glwe, meta);
        (ciphertext, rest)
    }

    /// Same as [`Self::take_unnormalized_ckks_ciphertext_scratch`], using `ct` both as
    /// the layout description and as the source of the metadata.
    fn take_unnormalized_ckks_ciphertext_like_scratch<C>(self, ct: &C) -> (UnnormalizedCKKSCiphertext<BE::BufMut<'a>>, Self)
    where
        BE: 'a,
        C: GLWEInfos + CKKSInfos,
    {
        let meta = ct.meta();
        self.take_unnormalized_ckks_ciphertext_scratch(ct, meta)
    }
}
/// Blanket implementation: every `T: ScratchArenaTakeCore<'a, BE>` receives the
/// provided methods of `ScratchArenaTakeCKKS` unchanged.
impl<'a, BE, T> ScratchArenaTakeCKKS<'a, BE> for T
where
BE: Backend + 'a,
T: ScratchArenaTakeCore<'a, BE>,
{
}
/// Limb-maintenance operations on CKKS ciphertexts.
///
/// The per-method notes describe the `Module<BE>` implementation found in this file;
/// other implementors may differ.
pub trait CKKSMaintainOps {
/// Resizes `ct` to `size` limbs. The `Module` implementation returns an error when
/// `size` is smaller than `ceil(effective_k / base2k)`.
fn ckks_reallocate_limbs_checked(&self, ct: &mut CKKSCiphertext<Vec<u8>>, size: usize) -> Result<()>;
/// Resizes `ct` in place to exactly `ceil(effective_k / base2k)` limbs.
fn ckks_compact_limbs(&self, ct: &mut CKKSCiphertext<Vec<u8>>) -> Result<()>;
/// Returns a newly allocated `Vec<u8>`-backed ciphertext resized to
/// `ceil(effective_k / base2k)` limbs, carrying `ct`'s metadata and the leading bytes
/// of its data buffer. `ct` itself is only read.
fn ckks_compact_limbs_copy<D>(&self, ct: &CKKSCiphertext<D>) -> Result<CKKSCiphertext<Vec<u8>>>
where
D: HostDataRef;
}
/// Provided implementations of the in-place limb-maintenance operations.
///
/// `Module<BE>` forwards its `CKKSMaintainOps` methods to these hooks.
#[doc(hidden)]
pub trait CKKSMaintainOpsDefault<BE: Backend> {
    /// Resizes `ct` to `size` limbs after checking the request against its metadata.
    ///
    /// # Errors
    /// Returns `CKKSCompositionError::LimbReallocationShrinksBelowMetadata` when `size`
    /// is smaller than `ceil(effective_k / base2k)`; `ct` is not modified in that case.
    fn ckks_reallocate_limbs_checked_default(&self, ct: &mut CKKSCiphertext<Vec<u8>>, size: usize) -> Result<()> {
        let base2k = ct.base2k().as_usize();
        let min_limbs = ct.effective_k().div_ceil(base2k);
        if size < min_limbs {
            anyhow::bail!(CKKSCompositionError::LimbReallocationShrinksBelowMetadata {
                max_k: ct.max_k().as_usize(),
                log_delta: ct.log_delta(),
                base2k,
                requested_limbs: size,
            });
        }
        ct.data_mut().reallocate_limbs(size);
        Ok(())
    }

    /// Resizes `ct` to exactly `ceil(effective_k / base2k)` limbs.
    ///
    /// # Errors
    /// Propagates any error from
    /// [`Self::ckks_reallocate_limbs_checked_default`].
    fn ckks_compact_limbs_default(&self, ct: &mut CKKSCiphertext<Vec<u8>>) -> Result<()> {
        let base2k = ct.base2k().as_usize();
        let size = ct.effective_k().div_ceil(base2k);
        self.ckks_reallocate_limbs_checked_default(ct, size)
    }
}
/// Implements `CKKSMaintainOpsDefault<$be>` for `poulpy_hal::layouts::Module<$be>` with
/// an empty body, so the trait's provided methods are used as-is.
///
/// The expansion names the trait as `$crate::layouts::ciphertext::CKKSMaintainOpsDefault`.
#[macro_export]
macro_rules! impl_ckks_maintain_ops_defaults {
($be:ty) => {
impl $crate::layouts::ciphertext::CKKSMaintainOpsDefault<$be> for ::poulpy_hal::layouts::Module<$be> {}
};
}
// Re-export so the macro is also reachable through this module's path.
pub use crate::impl_ckks_maintain_ops_defaults;
impl<BE: Backend> CKKSMaintainOps for Module<BE>
where
BE: HostBackend<OwnedBuf = Vec<u8>>,
Module<BE>: CKKSMaintainOpsDefault<BE> + CKKSModuleAlloc<BE>,
{
fn ckks_reallocate_limbs_checked(&self, ct: &mut CKKSCiphertext<Vec<u8>>, size: usize) -> Result<()> {
self.ckks_reallocate_limbs_checked_default(ct, size)
}
fn ckks_compact_limbs(&self, ct: &mut CKKSCiphertext<Vec<u8>>) -> Result<()> {
self.ckks_compact_limbs_default(ct)
}
fn ckks_compact_limbs_copy<D>(&self, ct: &CKKSCiphertext<D>) -> Result<CKKSCiphertext<Vec<u8>>>
where
D: HostDataRef,
{
let size = ct.effective_k().div_ceil(ct.base2k().as_usize());
let mut compact = self.ckks_ciphertext_alloc_from_infos(ct);
compact.meta = ct.meta();
self.ckks_reallocate_limbs_checked_default(&mut compact, size)?;
let dst_len = compact.data().data.len();
compact.data_mut().data.copy_from_slice(&ct.data().data.as_ref()[..dst_len]);
Ok(compact)
}
}
/// Shorthand for a `CKKSCiphertext` whose state parameter is `Unnormalized`.
pub type UnnormalizedCKKSCiphertext<D> = CKKSCiphertext<D, Unnormalized>;
impl<D: Data> CKKSCiphertext<D, Unnormalized> {
pub fn new(ct: CKKSCiphertext<D>) -> Self {
Self::from_inner(ct.inner, ct.meta)
}
pub fn normalize<M, BE>(self, module: &M, scratch: &mut ScratchArena<'_, BE>) -> CKKSCiphertext<D>
where
BE: Backend,
M: GLWENormalize<BE>,
GLWE<D>: GLWEToBackendMut<BE>,
{
let mut normalized = CKKSCiphertext::<D>::from_inner(self.inner, self.meta);
module.glwe_normalize_assign(&mut normalized, scratch);
normalized
}
}
/// Wrapper around a mutable borrow of a `CKKSCiphertext<D>`.
///
/// Its `normalize` method consumes the wrapper and runs `glwe_normalize_assign` on the
/// borrowed ciphertext.
// NOTE(review): the name suggests the borrowed ciphertext is treated as temporarily
// unnormalized while this wrapper exists - confirm against callers.
pub struct UnnormalizedCKKSCiphertextRefMut<'a, D: Data> {
/// The borrowed ciphertext.
pub(crate) inner: &'a mut CKKSCiphertext<D>,
}
impl<'a, D: Data> UnnormalizedCKKSCiphertextRefMut<'a, D> {
pub(crate) fn new(inner: &'a mut CKKSCiphertext<D>) -> Self {
Self { inner }
}
pub(crate) fn normalize<M, BE>(self, module: &M, scratch: &mut ScratchArena<'_, BE>)
where
BE: Backend,
M: GLWENormalize<BE>,
CKKSCiphertext<D>: GLWEToBackendMut<BE>,
{
module.glwe_normalize_assign(self.inner, scratch);
}
}
/// Method-call sugar for the crate-level offset helpers, with `self` passed as the
/// first argument.
// NOTE(review): the meaning of the returned offset is defined by
// `crate::ckks_offset_unary` / `crate::ckks_offset_binary`, which are not visible here.
pub(crate) trait CKKSOffset: LWEInfos + CKKSInfos {
/// Returns `crate::ckks_offset_unary(self, a)`.
fn offset_unary<A>(&self, a: &A) -> usize
where
A: LWEInfos + CKKSInfos,
{
crate::ckks_offset_unary(self, a)
}
/// Returns `crate::ckks_offset_binary(self, a, b)`.
fn offset_binary<A, B>(&self, a: &A, b: &B) -> usize
where
A: LWEInfos + CKKSInfos,
B: LWEInfos + CKKSInfos,
{
crate::ckks_offset_binary(self, a, b)
}
}
/// Blanket implementation: every type that is both `LWEInfos` and `CKKSInfos` gets the
/// `CKKSOffset` methods.
impl<T> CKKSOffset for T where T: LWEInfos + CKKSInfos {}