1use anyhow::Result;
2use poulpy_core::{
3 GLWEAdd, GLWENormalize, GLWEShift,
4 layouts::{Base2K, GLWEPlaintext, GLWEToBackendMut, GLWEToBackendRef, LWEInfos},
5};
6use poulpy_hal::{
7 api::{VecZnxRshAddCoeffIntoBackend, VecZnxRshAddIntoBackend, VecZnxRshTmpBytes},
8 layouts::{Backend, ScratchArena, VecZnx},
9};
10
11use crate::{
12 CKKSInfos, CKKSMeta, SetCKKSInfos, checked_log_budget_sub, ckks_offset_binary, ckks_offset_unary,
13 default::CKKSPlaintextDefault,
14 layouts::{CKKSPlaintext, CKKSPlaintextVecHostCodec},
15};
16
17pub(crate) fn ckks_one_pt<BE>(base2k: Base2K) -> Result<CKKSPlaintext<BE::OwnedBuf>>
18where
19 BE: Backend,
20{
21 let meta = CKKSMeta {
22 log_delta: 1,
23 log_budget: 0,
24 };
25
26 let mut host_pt = CKKSPlaintext::from_inner(
27 GLWEPlaintext::alloc_with_meta(1usize.into(), base2k, meta.min_k(base2k)),
28 meta,
29 );
30 host_pt.encode_host_floats(&[1.0f64])?;
31
32 let shape = host_pt.inner.data.shape();
33 let backend_inner = GLWEPlaintext {
34 data: VecZnx::from_data_with_max_size(
35 BE::from_host_bytes(host_pt.inner.data.data.as_ref()),
36 shape.n(),
37 shape.cols(),
38 shape.size(),
39 shape.max_size(),
40 ),
41 base2k,
42 };
43 Ok(CKKSPlaintext::from_inner(backend_inner, meta))
44}
45
46pub trait CKKSAddDefault<BE: Backend> {
47 fn ckks_add_tmp_bytes_default(&self) -> usize
48 where
49 Self: GLWEShift<BE> + GLWENormalize<BE>,
50 {
51 self.glwe_shift_tmp_bytes().max(self.glwe_normalize_tmp_bytes())
52 }
53
54 fn ckks_add_pt_vec_tmp_bytes_default(&self) -> usize
55 where
56 Self: GLWEShift<BE> + GLWENormalize<BE> + VecZnxRshTmpBytes,
57 {
58 self.glwe_shift_tmp_bytes()
59 .max(self.vec_znx_rsh_tmp_bytes())
60 .max(self.glwe_normalize_tmp_bytes())
61 }
62
63 fn ckks_add_pt_const_tmp_bytes_default(&self) -> usize
64 where
65 Self: GLWEShift<BE> + GLWENormalize<BE> + VecZnxRshTmpBytes,
66 {
67 self.glwe_shift_tmp_bytes()
68 .max(self.glwe_normalize_tmp_bytes())
69 .max(self.vec_znx_rsh_tmp_bytes())
70 }
71
72 fn ckks_add_into_default<Dst, A, B>(&self, dst: &mut Dst, a: &A, b: &B, scratch: &mut ScratchArena<'_, BE>) -> Result<()>
73 where
74 Self: GLWEAdd<BE> + GLWEShift<BE> + GLWENormalize<BE>,
75 Dst: GLWEToBackendMut<BE> + LWEInfos + SetCKKSInfos + CKKSInfos,
76 A: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
77 B: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
78 {
79 self.ckks_add_into_unsafe_default(dst, a, b, scratch)?;
80 self.glwe_normalize_assign(dst, scratch);
81 Ok(())
82 }
83
84 fn ckks_add_into_unsafe_default<Dst, A, B>(
85 &self,
86 dst: &mut Dst,
87 a: &A,
88 b: &B,
89 scratch: &mut ScratchArena<'_, BE>,
90 ) -> Result<()>
91 where
92 Self: GLWEAdd<BE> + GLWEShift<BE>,
93 Dst: GLWEToBackendMut<BE> + LWEInfos + SetCKKSInfos + CKKSInfos,
94 A: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
95 B: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
96 {
97 let offset = ckks_offset_binary(dst, a, b);
98
99 if offset == 0 && a.log_budget() == b.log_budget() {
100 self.glwe_add_into(dst, a, b);
101 } else if a.log_budget() <= b.log_budget() {
102 self.glwe_lsh(dst, a, offset, scratch);
103 self.glwe_lsh_add(dst, b, b.log_budget() - a.log_budget() + offset, scratch);
104 } else {
105 self.glwe_lsh(dst, b, offset, scratch);
106 self.glwe_lsh_add(dst, a, a.log_budget() - b.log_budget() + offset, scratch);
107 }
108
109 let log_budget = checked_log_budget_sub("add", a.log_budget().min(b.log_budget()), offset)?;
110 dst.set_log_delta(a.log_delta().min(b.log_delta()));
111 dst.set_log_budget(log_budget);
112 Ok(())
113 }
114
115 fn ckks_add_assign_default<Dst, A>(&self, dst: &mut Dst, a: &A, scratch: &mut ScratchArena<'_, BE>) -> Result<()>
116 where
117 Self: GLWEAdd<BE> + GLWEShift<BE> + GLWENormalize<BE>,
118 Dst: GLWEToBackendMut<BE> + CKKSInfos + SetCKKSInfos,
119 A: GLWEToBackendRef<BE> + CKKSInfos,
120 {
121 self.ckks_add_assign_unsafe_default(dst, a, scratch)?;
122 self.glwe_normalize_assign(dst, scratch);
123 Ok(())
124 }
125
126 fn ckks_add_assign_unsafe_default<Dst, A>(&self, dst: &mut Dst, a: &A, scratch: &mut ScratchArena<'_, BE>) -> Result<()>
127 where
128 Self: GLWEAdd<BE> + GLWEShift<BE>,
129 Dst: GLWEToBackendMut<BE> + CKKSInfos + SetCKKSInfos,
130 A: GLWEToBackendRef<BE> + CKKSInfos,
131 {
132 let dst_log_budget = dst.log_budget();
133
134 if dst_log_budget < a.log_budget() {
135 self.glwe_lsh_add(dst, a, a.log_budget() - dst_log_budget, scratch);
136 } else if dst_log_budget > a.log_budget() {
137 self.glwe_lsh_assign(dst, dst_log_budget - a.log_budget(), scratch);
138 self.glwe_add_assign(dst, a);
139 } else {
140 self.glwe_add_assign(dst, a);
141 }
142
143 dst.set_log_budget(dst_log_budget.min(a.log_budget()));
144 dst.set_log_delta(dst.log_delta().min(a.log_delta()));
145 Ok(())
146 }
147
148 fn ckks_add_one_assign_default<Dst>(&self, dst: &mut Dst, scratch: &mut ScratchArena<'_, BE>) -> Result<()>
149 where
150 Self: GLWENormalize<BE> + VecZnxRshAddCoeffIntoBackend<BE> + CKKSPlaintextDefault<BE>,
151 Dst: GLWEToBackendMut<BE> + LWEInfos + CKKSInfos,
152 {
153 let one = ckks_one_pt::<BE>(dst.base2k())?;
154 self.ckks_add_pt_const_assign_default(dst, 0, &one, 0, scratch)
155 }
156
157 fn ckks_add_pt_vec_into_default<Dst, A, P>(
158 &self,
159 dst: &mut Dst,
160 a: &A,
161 pt: &P,
162 scratch: &mut ScratchArena<'_, BE>,
163 ) -> Result<()>
164 where
165 Self: VecZnxRshAddIntoBackend<BE> + GLWEShift<BE> + GLWENormalize<BE> + CKKSPlaintextDefault<BE>,
166 Dst: GLWEToBackendMut<BE> + LWEInfos + CKKSInfos + SetCKKSInfos,
167 A: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
168 P: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
169 {
170 self.ckks_add_pt_vec_into_unsafe_default(dst, a, pt, scratch)?;
171 self.glwe_normalize_assign(dst, scratch);
172 Ok(())
173 }
174
175 fn ckks_add_pt_vec_into_unsafe_default<Dst, A, P>(
176 &self,
177 dst: &mut Dst,
178 a: &A,
179 pt: &P,
180 scratch: &mut ScratchArena<'_, BE>,
181 ) -> Result<()>
182 where
183 Self: VecZnxRshAddIntoBackend<BE> + GLWEShift<BE> + CKKSPlaintextDefault<BE>,
184 Dst: GLWEToBackendMut<BE> + LWEInfos + CKKSInfos + SetCKKSInfos,
185 A: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
186 P: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
187 {
188 let offset = ckks_offset_unary(dst, a);
189 self.glwe_lsh(dst, a, offset, scratch);
190 dst.set_meta(a.meta());
191 dst.set_log_budget(checked_log_budget_sub("add_pt_vec", a.log_budget(), offset)?);
192 self.ckks_add_pt_vec_assign_unsafe_default(dst, pt, scratch)?;
193 Ok(())
194 }
195
196 fn ckks_add_pt_vec_assign_default<Dst, P>(&self, dst: &mut Dst, pt: &P, scratch: &mut ScratchArena<'_, BE>) -> Result<()>
197 where
198 Self: VecZnxRshAddIntoBackend<BE> + GLWENormalize<BE> + CKKSPlaintextDefault<BE>,
199 Dst: GLWEToBackendMut<BE> + LWEInfos + CKKSInfos,
200 P: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
201 {
202 self.ckks_add_pt_vec_assign_unsafe_default(dst, pt, scratch)?;
203 self.glwe_normalize_assign(dst, scratch);
204 Ok(())
205 }
206
207 fn ckks_add_pt_vec_assign_unsafe_default<Dst, P>(
208 &self,
209 dst: &mut Dst,
210 pt: &P,
211 scratch: &mut ScratchArena<'_, BE>,
212 ) -> Result<()>
213 where
214 Self: VecZnxRshAddIntoBackend<BE> + CKKSPlaintextDefault<BE>,
215 Dst: GLWEToBackendMut<BE> + LWEInfos + CKKSInfos,
216 P: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
217 {
218 CKKSPlaintextDefault::ckks_add_pt_vec_into_default(self, dst, pt, scratch)?;
219 Ok(())
220 }
221
222 fn ckks_add_pt_const_into_default<Dst, A, P>(
223 &self,
224 dst: &mut Dst,
225 a: &A,
226 dst_coeff: usize,
227 cst: &P,
228 const_coeff: usize,
229 scratch: &mut ScratchArena<'_, BE>,
230 ) -> Result<()>
231 where
232 Self: GLWEShift<BE> + GLWENormalize<BE> + VecZnxRshAddCoeffIntoBackend<BE> + CKKSPlaintextDefault<BE>,
233 Dst: GLWEToBackendMut<BE> + LWEInfos + CKKSInfos + SetCKKSInfos,
234 A: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
235 P: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
236 {
237 self.ckks_add_pt_const_into_unsafe_default(dst, a, dst_coeff, cst, const_coeff, scratch)?;
238 self.glwe_normalize_assign(dst, scratch);
239 Ok(())
240 }
241
242 fn ckks_add_pt_const_into_unsafe_default<Dst, A, P>(
243 &self,
244 dst: &mut Dst,
245 a: &A,
246 dst_coeff: usize,
247 cst: &P,
248 const_coeff: usize,
249 scratch: &mut ScratchArena<'_, BE>,
250 ) -> Result<()>
251 where
252 Self: GLWEShift<BE> + VecZnxRshAddCoeffIntoBackend<BE> + CKKSPlaintextDefault<BE>,
253 Dst: GLWEToBackendMut<BE> + LWEInfos + CKKSInfos + SetCKKSInfos,
254 A: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
255 P: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
256 {
257 let offset = ckks_offset_unary(dst, a);
258 self.glwe_lsh(dst, a, offset, scratch);
259 dst.set_meta(a.meta());
260 dst.set_log_budget(checked_log_budget_sub("add_const", a.log_budget(), offset)?);
261 self.ckks_add_pt_const_assign_unsafe_default(dst, dst_coeff, cst, const_coeff, scratch)
262 }
263
264 fn ckks_add_pt_const_assign_default<Dst, P>(
265 &self,
266 dst: &mut Dst,
267 dst_coeff: usize,
268 cst: &P,
269 const_coeff: usize,
270 scratch: &mut ScratchArena<'_, BE>,
271 ) -> Result<()>
272 where
273 Self: GLWENormalize<BE> + VecZnxRshAddCoeffIntoBackend<BE> + CKKSPlaintextDefault<BE>,
274 Dst: GLWEToBackendMut<BE> + LWEInfos + CKKSInfos,
275 P: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
276 {
277 self.ckks_add_pt_const_assign_unsafe_default(dst, dst_coeff, cst, const_coeff, scratch)?;
278 self.glwe_normalize_assign(dst, scratch);
279 Ok(())
280 }
281
282 fn ckks_add_pt_const_assign_unsafe_default<Dst, P>(
283 &self,
284 dst: &mut Dst,
285 dst_coeff: usize,
286 cst: &P,
287 const_coeff: usize,
288 scratch: &mut ScratchArena<'_, BE>,
289 ) -> Result<()>
290 where
291 Self: VecZnxRshAddCoeffIntoBackend<BE> + CKKSPlaintextDefault<BE>,
292 Dst: GLWEToBackendMut<BE> + LWEInfos + CKKSInfos,
293 P: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
294 {
295 CKKSPlaintextDefault::ckks_add_pt_const_into_default(self, dst, dst_coeff, cst, const_coeff, scratch)?;
296 Ok(())
297 }
298}