1use anyhow::Result;
2use poulpy_core::{
3 GLWECopy, GLWEMulConst, GLWEMulPlain, GLWERotate, GLWETensoring, ScratchArenaTakeCore,
4 layouts::{
5 GGLWEInfos, GLWE, GLWEInfos, GLWELayout, GLWEPlaintextLayout, GLWETensor, GLWEToBackendMut, GLWEToBackendRef, LWEInfos,
6 ModuleCoreAlloc, TorusPrecision, prepared::GLWETensorKeyPreparedToBackendRef,
7 },
8};
9use poulpy_hal::{
10 api::VecZnxCopyBackend,
11 layouts::{Backend, ScratchArena},
12};
13
14use crate::{CKKSInfos, CKKSMeta, SetCKKSInfos, checked_log_budget_sub, checked_mul_ct_log_budget, checked_mul_pt_log_budget};
15
16pub trait CKKSMulDefault<BE: Backend> {
17 fn ckks_mul_tmp_bytes_default<R, T>(&self, res: &R, tsk: &T) -> usize
18 where
19 R: GLWEInfos,
20 T: GGLWEInfos,
21 Self: GLWETensoring<BE>,
22 {
23 let glwe_layout = GLWELayout {
24 n: res.n(),
25 base2k: res.base2k(),
26 k: TorusPrecision(res.max_k().as_u32()),
27 rank: res.rank(),
28 };
29
30 let lvl_0 = GLWETensor::bytes_of_from_infos(&glwe_layout);
31 let lvl_1 = self
32 .glwe_tensor_apply_tmp_bytes(&glwe_layout, res, res)
33 .max(self.glwe_tensor_relinearize_tmp_bytes(res, &glwe_layout, tsk));
34
35 lvl_0 + lvl_1
36 }
37
38 fn ckks_mul_into_default<Dst, A, B, T>(
39 &self,
40 dst: &mut Dst,
41 a: &A,
42 b: &B,
43 tsk: &T,
44 scratch: &mut ScratchArena<'_, BE>,
45 ) -> Result<()>
46 where
47 Self: GLWETensoring<BE> + GLWECopy<BE> + ModuleCoreAlloc<OwnedBuf = BE::OwnedBuf>,
48 Dst: GLWEToBackendMut<BE> + CKKSInfos + SetCKKSInfos + GLWEInfos,
49 A: GLWEToBackendRef<BE> + CKKSInfos + GLWEInfos,
50 B: GLWEToBackendRef<BE> + CKKSInfos + GLWEInfos,
51 T: GLWETensorKeyPreparedToBackendRef<BE> + GGLWEInfos,
52 {
53 let (res_log_budget, res_log_delta, cnv_offset) = get_mul_ct_params(dst, a, b)?;
54
55 let tensor_layout = GLWELayout {
56 n: dst.n(),
57 base2k: dst.base2k(),
58 k: a.max_k().max(b.max_k()),
59 rank: dst.rank(),
60 };
61 let scratch_local = scratch.borrow();
62 let (mut tmp, mut scratch_local) = scratch_local.take_glwe_tensor_scratch(&tensor_layout);
63 self.glwe_tensor_apply(
64 cnv_offset,
65 &mut tmp,
66 a,
67 a.effective_k(),
68 b,
69 b.effective_k(),
70 &mut scratch_local,
71 );
72 self.glwe_tensor_relinearize(dst, &tmp, tsk, tmp.size() + tsk.dsize().as_usize(), &mut scratch_local);
73
74 dst.set_log_budget(res_log_budget);
75 dst.set_log_delta(res_log_delta);
76 Ok(())
77 }
78
79 fn ckks_mul_assign_default<Dst, A, T>(&self, dst: &mut Dst, a: &A, tsk: &T, scratch: &mut ScratchArena<'_, BE>) -> Result<()>
80 where
81 Self: GLWETensoring<BE> + GLWECopy<BE> + ModuleCoreAlloc<OwnedBuf = BE::OwnedBuf>,
82 Dst: GLWEToBackendMut<BE> + GLWEToBackendRef<BE> + CKKSInfos + SetCKKSInfos + GLWEInfos,
83 A: GLWEToBackendRef<BE> + CKKSInfos + GLWEInfos,
84 T: GLWETensorKeyPreparedToBackendRef<BE> + GGLWEInfos,
85 {
86 let (res_log_budget, res_log_delta, cnv_offset) = get_mul_ct_params(dst, dst, a)?;
87
88 let tensor_layout = GLWELayout {
89 n: dst.n(),
90 base2k: dst.base2k(),
91 k: dst.max_k().max(a.max_k()),
92 rank: dst.rank(),
93 };
94 let scratch_local = scratch.borrow();
95 let (mut tmp, mut scratch_local) = scratch_local.take_glwe_tensor_scratch(&tensor_layout);
96 self.glwe_tensor_apply(
97 cnv_offset,
98 &mut tmp,
99 &*dst,
100 dst.effective_k(),
101 a,
102 a.effective_k(),
103 &mut scratch_local,
104 );
105 self.glwe_tensor_relinearize(dst, &tmp, tsk, tmp.size() + tsk.dsize().as_usize(), &mut scratch_local);
106
107 dst.set_log_budget(res_log_budget);
108 dst.set_log_delta(res_log_delta);
109 Ok(())
110 }
111
112 fn ckks_square_tmp_bytes_default<R, T>(&self, res: &R, tsk: &T) -> usize
113 where
114 R: GLWEInfos,
115 T: GGLWEInfos,
116 Self: GLWETensoring<BE>,
117 {
118 let glwe_layout = GLWELayout {
119 n: res.n(),
120 base2k: res.base2k(),
121 k: TorusPrecision(res.max_k().as_u32()),
122 rank: res.rank(),
123 };
124
125 let lvl_0 = GLWETensor::bytes_of_from_infos(&glwe_layout);
126 let lvl_1 = self
127 .glwe_tensor_square_apply_tmp_bytes(&glwe_layout, res)
128 .max(self.glwe_tensor_relinearize_tmp_bytes(res, &glwe_layout, tsk));
129
130 lvl_0 + lvl_1
131 }
132
133 fn ckks_square_into_default<Dst, A, T>(&self, dst: &mut Dst, a: &A, tsk: &T, scratch: &mut ScratchArena<'_, BE>) -> Result<()>
134 where
135 Self: GLWETensoring<BE> + GLWECopy<BE> + ModuleCoreAlloc<OwnedBuf = BE::OwnedBuf>,
136 Dst: GLWEToBackendMut<BE> + CKKSInfos + SetCKKSInfos + GLWEInfos,
137 A: GLWEToBackendRef<BE> + CKKSInfos + GLWEInfos,
138 T: GLWETensorKeyPreparedToBackendRef<BE> + GGLWEInfos,
139 {
140 let (res_log_budget, res_log_delta, cnv_offset) = get_mul_ct_params(dst, a, a)?;
141
142 let tensor_layout = GLWELayout {
143 n: dst.n(),
144 base2k: dst.base2k(),
145 k: a.max_k(),
146 rank: dst.rank(),
147 };
148 let scratch_local = scratch.borrow();
149 let (mut tmp, mut scratch_local) = scratch_local.take_glwe_tensor_scratch(&tensor_layout);
150 self.glwe_tensor_square_apply(cnv_offset, &mut tmp, a, a.effective_k(), &mut scratch_local);
151 self.glwe_tensor_relinearize(dst, &tmp, tsk, tmp.size() + tsk.dsize().as_usize(), &mut scratch_local);
152
153 dst.set_log_budget(res_log_budget);
154 dst.set_log_delta(res_log_delta);
155 Ok(())
156 }
157
158 fn ckks_square_assign_default<Dst, T>(&self, dst: &mut Dst, tsk: &T, scratch: &mut ScratchArena<'_, BE>) -> Result<()>
159 where
160 Self: GLWETensoring<BE> + GLWECopy<BE> + ModuleCoreAlloc<OwnedBuf = BE::OwnedBuf>,
161 Dst: GLWEToBackendMut<BE> + GLWEToBackendRef<BE> + CKKSInfos + SetCKKSInfos + GLWEInfos,
162 T: GLWETensorKeyPreparedToBackendRef<BE> + GGLWEInfos,
163 {
164 let (res_log_budget, res_log_delta, cnv_offset) = get_mul_ct_params(dst, dst, dst)?;
165
166 let tensor_layout = GLWELayout {
167 n: dst.n(),
168 base2k: dst.base2k(),
169 k: dst.max_k(),
170 rank: dst.rank(),
171 };
172 let scratch_local = scratch.borrow();
173 let (mut tmp, mut scratch_local) = scratch_local.take_glwe_tensor_scratch(&tensor_layout);
174 self.glwe_tensor_square_apply(cnv_offset, &mut tmp, &*dst, dst.effective_k(), &mut scratch_local);
175 self.glwe_tensor_relinearize(dst, &tmp, tsk, tmp.size() + tsk.dsize().as_usize(), &mut scratch_local);
176
177 dst.set_log_budget(res_log_budget);
178 dst.set_log_delta(res_log_delta);
179 Ok(())
180 }
181
182 fn ckks_mul_pt_vec_tmp_bytes_default<R, A>(&self, res: &R, a: &A, b: &CKKSMeta) -> usize
183 where
184 R: GLWEInfos,
185 A: GLWEInfos,
186 Self: GLWEMulPlain<BE>,
187 {
188 let b_infos = GLWEPlaintextLayout {
189 n: res.n(),
190 base2k: res.base2k(),
191 k: b.min_k(res.base2k()),
192 };
193 self.glwe_mul_plain_tmp_bytes(res, a, &b_infos)
194 }
195
196 fn ckks_mul_pt_const_tmp_bytes_default<R, A>(&self, res: &R, a: &A, b: &CKKSMeta) -> usize
197 where
198 R: GLWEInfos,
199 A: GLWEInfos,
200 Self: GLWEMulConst<BE> + GLWERotate<BE>,
201 {
202 let b_infos = GLWEPlaintextLayout {
203 n: res.n(),
204 base2k: res.base2k(),
205 k: b.min_k(res.base2k()),
206 };
207 GLWE::<Vec<u8>>::bytes_of_from_infos(res)
208 + self
209 .glwe_mul_const_tmp_bytes(res, a, &b_infos)
210 .max(self.glwe_rotate_tmp_bytes())
211 }
212
213 fn ckks_mul_pt_vec_into_default<Dst, A, P>(
214 &self,
215 dst: &mut Dst,
216 a: &A,
217 pt: &P,
218 scratch: &mut ScratchArena<'_, BE>,
219 ) -> Result<()>
220 where
221 P: GLWEToBackendRef<BE> + LWEInfos + GLWEInfos + CKKSInfos,
222 Self: GLWECopy<BE> + GLWEMulPlain<BE> + ModuleCoreAlloc<OwnedBuf = BE::OwnedBuf> + VecZnxCopyBackend<BE>,
223 Dst: GLWEToBackendMut<BE> + CKKSInfos + SetCKKSInfos + GLWEInfos,
224 A: GLWEToBackendRef<BE> + CKKSInfos + GLWEInfos,
225 {
226 let (res_log_budget, res_log_delta, cnv_offset) = get_mul_pt_params(dst, a, pt)?;
227 self.glwe_mul_plain(cnv_offset, dst, a, a.effective_k(), pt, pt.max_k().as_usize(), scratch);
228 dst.set_log_budget(res_log_budget);
229 dst.set_log_delta(res_log_delta);
230 Ok(())
231 }
232
233 fn ckks_mul_pt_vec_assign_default<Dst, P>(&self, dst: &mut Dst, pt: &P, scratch: &mut ScratchArena<'_, BE>) -> Result<()>
234 where
235 P: GLWEToBackendRef<BE> + LWEInfos + GLWEInfos + CKKSInfos,
236 Self: GLWECopy<BE> + GLWEMulPlain<BE> + ModuleCoreAlloc<OwnedBuf = BE::OwnedBuf> + VecZnxCopyBackend<BE>,
237 Dst: GLWEToBackendMut<BE> + CKKSInfos + SetCKKSInfos + GLWEInfos,
238 {
239 let (res_log_budget, res_log_delta, cnv_offset) = get_mul_pt_params(dst, dst, pt)?;
240 let dst_effective_k = dst.effective_k();
241 self.glwe_mul_plain_assign(cnv_offset, dst, dst_effective_k, pt, pt.max_k().as_usize(), scratch);
242 dst.set_log_budget(res_log_budget);
243 dst.set_log_delta(res_log_delta);
244 Ok(())
245 }
246
247 fn ckks_mul_pt_const_into_default<Dst, A, P>(
248 &self,
249 dst: &mut Dst,
250 a: &A,
251 pt: &P,
252 pt_coeff: usize,
253 scratch: &mut ScratchArena<'_, BE>,
254 ) -> Result<()>
255 where
256 P: GLWEToBackendRef<BE> + LWEInfos + GLWEInfos + CKKSInfos,
257 Self: GLWEMulConst<BE>,
258 Dst: GLWEToBackendMut<BE> + CKKSInfos + SetCKKSInfos + GLWEInfos,
259 A: GLWEToBackendRef<BE> + CKKSInfos + GLWEInfos,
260 {
261 let (res_log_budget, res_log_delta, cnv_offset) = get_mul_pt_params(dst, a, pt)?;
262 self.glwe_mul_const(cnv_offset, dst, a, pt, pt_coeff, scratch);
263
264 dst.set_log_budget(res_log_budget);
265 dst.set_log_delta(res_log_delta);
266 Ok(())
267 }
268
269 fn ckks_mul_pt_const_assign_default<Dst, P>(
270 &self,
271 dst: &mut Dst,
272 cnst: &P,
273 cnst_coeff: usize,
274 scratch: &mut ScratchArena<'_, BE>,
275 ) -> Result<()>
276 where
277 P: GLWEToBackendRef<BE> + LWEInfos + GLWEInfos + CKKSInfos,
278 Self: GLWEMulConst<BE>,
279 Dst: GLWEToBackendMut<BE> + GLWEToBackendRef<BE> + CKKSInfos + SetCKKSInfos + GLWEInfos,
280 {
281 let (res_log_budget, res_log_delta, cnv_offset) = get_mul_pt_params(dst, dst, cnst)?;
282
283 self.glwe_mul_const_assign(cnv_offset, dst, cnst, cnst_coeff, scratch);
284
285 dst.set_log_budget(res_log_budget);
286 dst.set_log_delta(res_log_delta);
287 Ok(())
288 }
289}
290
291fn get_mul_ct_params<R, A, B>(res: &R, a: &A, b: &B) -> Result<(usize, usize, usize)>
292where
293 R: LWEInfos + CKKSInfos,
294 A: LWEInfos + CKKSInfos,
295 B: LWEInfos + CKKSInfos,
296{
297 let res_log_budget = checked_mul_ct_log_budget("mul", a.log_budget(), b.log_budget(), a.log_delta(), b.log_delta())?;
298 let res_log_delta = a.log_delta().min(b.log_delta());
299
300 let res_offset = (res_log_budget + res_log_delta).saturating_sub(res.max_k().as_usize());
301 let cnv_offset = a.effective_k().max(b.effective_k()) + res_offset;
309
310 Ok((
311 checked_log_budget_sub("mul", res_log_budget, res_offset)?,
312 res_log_delta,
313 cnv_offset,
314 ))
315}
316
317fn get_mul_pt_params<R, A, B>(res: &R, a: &A, b: &B) -> Result<(usize, usize, usize)>
318where
319 R: LWEInfos + CKKSInfos,
320 A: LWEInfos + CKKSInfos,
321 B: LWEInfos + CKKSInfos,
322{
323 let res_log_budget = checked_mul_pt_log_budget("mul", a.log_budget(), b.log_budget(), a.log_delta(), b.log_delta())?;
324 let res_log_delta = a.log_delta();
325 let res_offset = (res_log_budget + res_log_delta).saturating_sub(res.max_k().as_usize());
326 let cnv_offset = b.max_k().as_usize() + res_offset;
327
328 Ok((
329 checked_log_budget_sub("mul", res_log_budget, res_offset)?,
330 res_log_delta,
331 cnv_offset,
332 ))
333}