1use std::{
7 fmt,
8 marker::PhantomData,
9 ops::{Deref, DerefMut},
10};
11
12use anyhow::Result;
13use poulpy_core::layouts::{Base2K, Degree, GLWE, GLWEInfos, GLWEToBackendMut, GLWEToBackendRef, GLWEViewMut, LWEInfos, Rank};
14use poulpy_core::{GLWENormalize, ScratchArenaTakeCore};
15use poulpy_hal::layouts::{Backend, Data, HostBackend, HostDataRef, Module, ScratchArena};
16
17use crate::{CKKSInfos, CKKSMeta, SetCKKSInfos, error::CKKSCompositionError, layouts::CKKSModuleAlloc};
18
mod sealed {
    /// Supertrait that can only be named inside this module, which keeps
    /// [`super::CKKSNormalizationState`] from being implemented anywhere else.
    pub trait Sealed {}
}

/// Type-level tag for a ciphertext whose limbs are known to be normalized.
///
/// This is the default state parameter of [`CKKSCiphertext`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct Normalized;

/// Type-level tag for a ciphertext whose limbs are not guaranteed to be normalized.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct Unnormalized;

/// Sealed marker trait for the normalization-state parameter of [`CKKSCiphertext`].
///
/// Only [`Normalized`] and [`Unnormalized`] implement it, because its supertrait lives
/// in a private module.
pub trait CKKSNormalizationState: sealed::Sealed {}

impl sealed::Sealed for Normalized {}
impl CKKSNormalizationState for Normalized {}

impl sealed::Sealed for Unnormalized {}
impl CKKSNormalizationState for Unnormalized {}
37
38pub struct CKKSCiphertext<D: Data, S: CKKSNormalizationState = Normalized> {
43 pub(crate) inner: GLWE<D>,
45 pub(crate) meta: CKKSMeta,
47 _state: PhantomData<S>,
48}
49
50impl<D: Data, S: CKKSNormalizationState> CKKSCiphertext<D, S> {
51 pub(crate) fn from_inner(inner: GLWE<D>, meta: CKKSMeta) -> Self {
52 Self {
53 inner,
54 meta,
55 _state: PhantomData,
56 }
57 }
58
59 pub fn to_host_owned<BE>(&self) -> CKKSCiphertext<Vec<u8>, S>
61 where
62 BE: Backend<OwnedBuf = D>,
63 {
64 CKKSCiphertext::<Vec<u8>, S>::from_inner(self.inner.to_host_owned::<BE>(), self.meta)
65 }
66
67 pub fn display_host<BE>(&self) -> String
69 where
70 BE: Backend<OwnedBuf = D>,
71 {
72 self.to_host_owned::<BE>().to_string()
73 }
74
75 pub fn to_ref<BE: Backend>(&self) -> GLWE<BE::BufRef<'_>>
76 where
77 GLWE<D>: GLWEToBackendRef<BE>,
78 {
79 GLWEToBackendRef::to_backend_ref(&self.inner)
80 }
81
82 pub fn to_mut<BE: Backend>(&mut self) -> GLWE<BE::BufMut<'_>>
83 where
84 GLWE<D>: GLWEToBackendMut<BE>,
85 {
86 GLWEToBackendMut::to_backend_mut(&mut self.inner)
87 }
88
89 pub fn set_meta_checked(&mut self, meta: CKKSMeta) -> Result<()> {
95 anyhow::ensure!(
96 meta.effective_k() <= self.max_k().as_usize(),
97 CKKSCompositionError::LimbReallocationShrinksBelowMetadata {
98 max_k: self.max_k().as_usize(),
99 log_delta: meta.log_delta(),
100 base2k: self.base2k().as_usize(),
101 requested_limbs: self.size(),
102 }
103 );
104 self.meta = meta;
105 Ok(())
106 }
107}
108
109impl<D: Data, S: CKKSNormalizationState> Deref for CKKSCiphertext<D, S> {
110 type Target = GLWE<D>;
111
112 fn deref(&self) -> &Self::Target {
113 &self.inner
114 }
115}
116
117impl<D: Data, S: CKKSNormalizationState> DerefMut for CKKSCiphertext<D, S> {
118 fn deref_mut(&mut self) -> &mut Self::Target {
119 &mut self.inner
120 }
121}
122
123impl<D: Data, S: CKKSNormalizationState> LWEInfos for CKKSCiphertext<D, S> {
124 fn base2k(&self) -> Base2K {
125 self.inner.base2k()
126 }
127
128 fn n(&self) -> Degree {
129 self.inner.n()
130 }
131
132 fn size(&self) -> usize {
133 self.inner.size()
134 }
135}
136
137impl<D: Data, S: CKKSNormalizationState> GLWEInfos for CKKSCiphertext<D, S> {
138 fn rank(&self) -> Rank {
139 self.inner.rank()
140 }
141}
142
143impl<D: Data, S: CKKSNormalizationState> CKKSInfos for CKKSCiphertext<D, S> {
144 fn meta(&self) -> CKKSMeta {
145 self.meta
146 }
147
148 fn log_delta(&self) -> usize {
149 self.meta.log_delta()
150 }
151
152 fn log_budget(&self) -> usize {
153 self.meta.log_budget()
154 }
155}
156
157impl<D: Data, S: CKKSNormalizationState> SetCKKSInfos for CKKSCiphertext<D, S> {
158 fn set_meta(&mut self, meta: CKKSMeta) {
159 self.meta = meta;
160 }
161}
162
163impl<D: HostDataRef, S: CKKSNormalizationState> fmt::Display for CKKSCiphertext<D, S> {
164 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
165 write!(f, "{}", self.inner)
166 }
167}
168
169impl<BE: Backend, D: Data, S: CKKSNormalizationState> GLWEToBackendRef<BE> for CKKSCiphertext<D, S>
170where
171 GLWE<D>: GLWEToBackendRef<BE>,
172{
173 fn to_backend_ref(&self) -> GLWE<BE::BufRef<'_>> {
174 GLWEToBackendRef::to_backend_ref(&self.inner)
175 }
176}
177
178impl<BE: Backend, D: Data, S: CKKSNormalizationState> GLWEToBackendMut<BE> for CKKSCiphertext<D, S>
179where
180 GLWE<D>: GLWEToBackendMut<BE>,
181{
182 fn to_backend_mut(&mut self) -> GLWE<BE::BufMut<'_>> {
183 GLWEToBackendMut::to_backend_mut(&mut self.inner)
184 }
185}
186
187pub struct CKKSCiphertextViewMut<'a, BE: Backend + 'a> {
193 inner: GLWEViewMut<'a, BE>,
194 meta: CKKSMeta,
195}
196
197impl<'a, BE: Backend + 'a> CKKSCiphertextViewMut<'a, BE> {
198 pub(crate) fn from_inner(inner: GLWEViewMut<'a, BE>, meta: CKKSMeta) -> Self {
199 Self { inner, meta }
200 }
201}
202
203impl<'a, BE: Backend + 'a> Deref for CKKSCiphertextViewMut<'a, BE> {
204 type Target = GLWEViewMut<'a, BE>;
205
206 fn deref(&self) -> &Self::Target {
207 &self.inner
208 }
209}
210
211impl<'a, BE: Backend + 'a> DerefMut for CKKSCiphertextViewMut<'a, BE> {
212 fn deref_mut(&mut self) -> &mut Self::Target {
213 &mut self.inner
214 }
215}
216
217impl<'a, BE: Backend + 'a> LWEInfos for CKKSCiphertextViewMut<'a, BE> {
218 fn base2k(&self) -> Base2K {
219 self.inner.base2k()
220 }
221
222 fn n(&self) -> Degree {
223 self.inner.n()
224 }
225
226 fn size(&self) -> usize {
227 self.inner.size()
228 }
229}
230
231impl<'a, BE: Backend + 'a> GLWEInfos for CKKSCiphertextViewMut<'a, BE> {
232 fn rank(&self) -> Rank {
233 self.inner.rank()
234 }
235}
236
237impl<'a, BE: Backend + 'a> CKKSInfos for CKKSCiphertextViewMut<'a, BE> {
238 fn meta(&self) -> CKKSMeta {
239 self.meta
240 }
241
242 fn log_delta(&self) -> usize {
243 self.meta.log_delta()
244 }
245
246 fn log_budget(&self) -> usize {
247 self.meta.log_budget()
248 }
249}
250
251impl<'a, BE: Backend + 'a> SetCKKSInfos for CKKSCiphertextViewMut<'a, BE> {
252 fn set_meta(&mut self, meta: CKKSMeta) {
253 self.meta = meta;
254 }
255}
256
257impl<'a, BE: Backend + 'a> GLWEToBackendRef<BE> for CKKSCiphertextViewMut<'a, BE> {
258 fn to_backend_ref(&self) -> GLWE<BE::BufRef<'_>> {
259 self.inner.to_backend_ref()
260 }
261}
262
263impl<'a, BE: Backend + 'a> GLWEToBackendMut<BE> for CKKSCiphertextViewMut<'a, BE> {
264 fn to_backend_mut(&mut self) -> GLWE<BE::BufMut<'_>> {
265 self.inner.to_backend_mut()
266 }
267}
268
269pub trait ScratchArenaTakeCKKS<'a, BE: Backend>: ScratchArenaTakeCore<'a, BE> + Sized {
271 fn take_ckks_ciphertext_scratch<I>(self, infos: &I, meta: CKKSMeta) -> (CKKSCiphertextViewMut<'a, BE>, Self)
272 where
273 BE: 'a,
274 I: GLWEInfos,
275 {
276 let (inner, scratch) = self.take_glwe_scratch(infos);
277 (CKKSCiphertextViewMut::from_inner(inner, meta), scratch)
278 }
279
280 fn take_ckks_ciphertext_like_scratch<C>(self, ct: &C) -> (CKKSCiphertextViewMut<'a, BE>, Self)
281 where
282 BE: 'a,
283 C: GLWEInfos + CKKSInfos,
284 {
285 self.take_ckks_ciphertext_scratch(ct, ct.meta())
286 }
287
288 fn take_unnormalized_ckks_ciphertext_scratch<I>(
289 self,
290 infos: &I,
291 meta: CKKSMeta,
292 ) -> (UnnormalizedCKKSCiphertext<BE::BufMut<'a>>, Self)
293 where
294 BE: 'a,
295 I: GLWEInfos,
296 {
297 let (inner, scratch) = self.take_glwe_scratch(infos);
298 (UnnormalizedCKKSCiphertext::from_inner(inner.into_inner(), meta), scratch)
299 }
300
301 fn take_unnormalized_ckks_ciphertext_like_scratch<C>(self, ct: &C) -> (UnnormalizedCKKSCiphertext<BE::BufMut<'a>>, Self)
302 where
303 BE: 'a,
304 C: GLWEInfos + CKKSInfos,
305 {
306 self.take_unnormalized_ckks_ciphertext_scratch(ct, ct.meta())
307 }
308}
309
310impl<'a, BE, T> ScratchArenaTakeCKKS<'a, BE> for T
311where
312 BE: Backend + 'a,
313 T: ScratchArenaTakeCore<'a, BE>,
314{
315}
316
317pub trait CKKSMaintainOps {
319 fn ckks_reallocate_limbs_checked(&self, ct: &mut CKKSCiphertext<Vec<u8>>, size: usize) -> Result<()>;
337
338 fn ckks_compact_limbs(&self, ct: &mut CKKSCiphertext<Vec<u8>>) -> Result<()>;
351
352 fn ckks_compact_limbs_copy<D>(&self, ct: &CKKSCiphertext<D>) -> Result<CKKSCiphertext<Vec<u8>>>
365 where
366 D: HostDataRef;
367}
368
369#[doc(hidden)]
370pub trait CKKSMaintainOpsDefault<BE: Backend> {
371 fn ckks_reallocate_limbs_checked_default(&self, ct: &mut CKKSCiphertext<Vec<u8>>, size: usize) -> Result<()> {
372 let base2k = ct.base2k().as_usize();
373 let required_limbs = ct.effective_k().div_ceil(base2k);
374 anyhow::ensure!(
375 size >= required_limbs,
376 CKKSCompositionError::LimbReallocationShrinksBelowMetadata {
377 max_k: ct.max_k().as_usize(),
378 log_delta: ct.log_delta(),
379 base2k,
380 requested_limbs: size,
381 }
382 );
383 ct.data_mut().reallocate_limbs(size);
384 Ok(())
385 }
386
387 fn ckks_compact_limbs_default(&self, ct: &mut CKKSCiphertext<Vec<u8>>) -> Result<()> {
388 let size = ct.effective_k().div_ceil(ct.base2k().as_usize());
389 self.ckks_reallocate_limbs_checked_default(ct, size)?;
390 Ok(())
391 }
392}
393
394#[macro_export]
395macro_rules! impl_ckks_maintain_ops_defaults {
396 ($be:ty) => {
397 impl $crate::layouts::ciphertext::CKKSMaintainOpsDefault<$be> for ::poulpy_hal::layouts::Module<$be> {}
398 };
399}
400pub use crate::impl_ckks_maintain_ops_defaults;
401
402impl<BE: Backend> CKKSMaintainOps for Module<BE>
403where
404 BE: HostBackend<OwnedBuf = Vec<u8>>,
405 Module<BE>: CKKSMaintainOpsDefault<BE> + CKKSModuleAlloc<BE>,
406{
407 fn ckks_reallocate_limbs_checked(&self, ct: &mut CKKSCiphertext<Vec<u8>>, size: usize) -> Result<()> {
408 self.ckks_reallocate_limbs_checked_default(ct, size)
409 }
410
411 fn ckks_compact_limbs(&self, ct: &mut CKKSCiphertext<Vec<u8>>) -> Result<()> {
412 self.ckks_compact_limbs_default(ct)
413 }
414
415 fn ckks_compact_limbs_copy<D>(&self, ct: &CKKSCiphertext<D>) -> Result<CKKSCiphertext<Vec<u8>>>
416 where
417 D: HostDataRef,
418 {
419 let size = ct.effective_k().div_ceil(ct.base2k().as_usize());
420 let mut compact = self.ckks_ciphertext_alloc_from_infos(ct);
421 compact.meta = ct.meta();
422 self.ckks_reallocate_limbs_checked_default(&mut compact, size)?;
423 let dst_len = compact.data().data.len();
424 compact.data_mut().data.copy_from_slice(&ct.data().data.as_ref()[..dst_len]);
425 Ok(compact)
426 }
427}
428
429pub type UnnormalizedCKKSCiphertext<D> = CKKSCiphertext<D, Unnormalized>;
443
444impl<D: Data> CKKSCiphertext<D, Unnormalized> {
445 pub fn new(ct: CKKSCiphertext<D>) -> Self {
447 Self::from_inner(ct.inner, ct.meta)
448 }
449
450 pub fn normalize<M, BE>(self, module: &M, scratch: &mut ScratchArena<'_, BE>) -> CKKSCiphertext<D>
457 where
458 BE: Backend,
459 M: GLWENormalize<BE>,
460 GLWE<D>: GLWEToBackendMut<BE>,
461 {
462 let mut normalized = CKKSCiphertext::<D>::from_inner(self.inner, self.meta);
463 module.glwe_normalize_assign(&mut normalized, scratch);
464 normalized
465 }
466}
467
468pub struct UnnormalizedCKKSCiphertextRefMut<'a, D: Data> {
469 pub(crate) inner: &'a mut CKKSCiphertext<D>,
470}
471
472impl<'a, D: Data> UnnormalizedCKKSCiphertextRefMut<'a, D> {
473 pub(crate) fn new(inner: &'a mut CKKSCiphertext<D>) -> Self {
474 Self { inner }
475 }
476
477 pub(crate) fn normalize<M, BE>(self, module: &M, scratch: &mut ScratchArena<'_, BE>)
478 where
479 BE: Backend,
480 M: GLWENormalize<BE>,
481 CKKSCiphertext<D>: GLWEToBackendMut<BE>,
482 {
483 module.glwe_normalize_assign(self.inner, scratch);
484 }
485}
486
487pub(crate) trait CKKSOffset: LWEInfos + CKKSInfos {
488 fn offset_unary<A>(&self, a: &A) -> usize
489 where
490 A: LWEInfos + CKKSInfos,
491 {
492 crate::ckks_offset_unary(self, a)
493 }
494
495 fn offset_binary<A, B>(&self, a: &A, b: &B) -> usize
496 where
497 A: LWEInfos + CKKSInfos,
498 B: LWEInfos + CKKSInfos,
499 {
500 crate::ckks_offset_binary(self, a, b)
501 }
502}
503
504impl<T> CKKSOffset for T where T: LWEInfos + CKKSInfos {}