Skip to main content

poulpy_ckks/default/
mul.rs

1use anyhow::Result;
2use poulpy_core::{
3    GLWECopy, GLWEMulConst, GLWEMulPlain, GLWERotate, GLWETensoring, ScratchArenaTakeCore,
4    layouts::{
5        GGLWEInfos, GLWE, GLWEInfos, GLWELayout, GLWEPlaintextLayout, GLWETensor, GLWEToBackendMut, GLWEToBackendRef, LWEInfos,
6        ModuleCoreAlloc, TorusPrecision, prepared::GLWETensorKeyPreparedToBackendRef,
7    },
8};
9use poulpy_hal::{
10    api::VecZnxCopyBackend,
11    layouts::{Backend, ScratchArena},
12};
13
14use crate::{CKKSInfos, CKKSMeta, SetCKKSInfos, checked_log_budget_sub, checked_mul_ct_log_budget, checked_mul_pt_log_budget};
15
16pub trait CKKSMulDefault<BE: Backend> {
17    fn ckks_mul_tmp_bytes_default<R, T>(&self, res: &R, tsk: &T) -> usize
18    where
19        R: GLWEInfos,
20        T: GGLWEInfos,
21        Self: GLWETensoring<BE>,
22    {
23        let glwe_layout = GLWELayout {
24            n: res.n(),
25            base2k: res.base2k(),
26            k: TorusPrecision(res.max_k().as_u32()),
27            rank: res.rank(),
28        };
29
30        let lvl_0 = GLWETensor::bytes_of_from_infos(&glwe_layout);
31        let lvl_1 = self
32            .glwe_tensor_apply_tmp_bytes(&glwe_layout, res, res)
33            .max(self.glwe_tensor_relinearize_tmp_bytes(res, &glwe_layout, tsk));
34
35        lvl_0 + lvl_1
36    }
37
38    fn ckks_mul_into_default<Dst, A, B, T>(
39        &self,
40        dst: &mut Dst,
41        a: &A,
42        b: &B,
43        tsk: &T,
44        scratch: &mut ScratchArena<'_, BE>,
45    ) -> Result<()>
46    where
47        Self: GLWETensoring<BE> + GLWECopy<BE> + ModuleCoreAlloc<OwnedBuf = BE::OwnedBuf>,
48        Dst: GLWEToBackendMut<BE> + CKKSInfos + SetCKKSInfos + GLWEInfos,
49        A: GLWEToBackendRef<BE> + CKKSInfos + GLWEInfos,
50        B: GLWEToBackendRef<BE> + CKKSInfos + GLWEInfos,
51        T: GLWETensorKeyPreparedToBackendRef<BE> + GGLWEInfos,
52    {
53        let (res_log_budget, res_log_delta, cnv_offset) = get_mul_ct_params(dst, a, b)?;
54
55        let tensor_layout = GLWELayout {
56            n: dst.n(),
57            base2k: dst.base2k(),
58            k: a.max_k().max(b.max_k()),
59            rank: dst.rank(),
60        };
61        let scratch_local = scratch.borrow();
62        let (mut tmp, mut scratch_local) = scratch_local.take_glwe_tensor_scratch(&tensor_layout);
63        self.glwe_tensor_apply(
64            cnv_offset,
65            &mut tmp,
66            a,
67            a.effective_k(),
68            b,
69            b.effective_k(),
70            &mut scratch_local,
71        );
72        self.glwe_tensor_relinearize(dst, &tmp, tsk, tmp.size() + tsk.dsize().as_usize(), &mut scratch_local);
73
74        dst.set_log_budget(res_log_budget);
75        dst.set_log_delta(res_log_delta);
76        Ok(())
77    }
78
79    fn ckks_mul_assign_default<Dst, A, T>(&self, dst: &mut Dst, a: &A, tsk: &T, scratch: &mut ScratchArena<'_, BE>) -> Result<()>
80    where
81        Self: GLWETensoring<BE> + GLWECopy<BE> + ModuleCoreAlloc<OwnedBuf = BE::OwnedBuf>,
82        Dst: GLWEToBackendMut<BE> + GLWEToBackendRef<BE> + CKKSInfos + SetCKKSInfos + GLWEInfos,
83        A: GLWEToBackendRef<BE> + CKKSInfos + GLWEInfos,
84        T: GLWETensorKeyPreparedToBackendRef<BE> + GGLWEInfos,
85    {
86        let (res_log_budget, res_log_delta, cnv_offset) = get_mul_ct_params(dst, dst, a)?;
87
88        let tensor_layout = GLWELayout {
89            n: dst.n(),
90            base2k: dst.base2k(),
91            k: dst.max_k().max(a.max_k()),
92            rank: dst.rank(),
93        };
94        let scratch_local = scratch.borrow();
95        let (mut tmp, mut scratch_local) = scratch_local.take_glwe_tensor_scratch(&tensor_layout);
96        self.glwe_tensor_apply(
97            cnv_offset,
98            &mut tmp,
99            &*dst,
100            dst.effective_k(),
101            a,
102            a.effective_k(),
103            &mut scratch_local,
104        );
105        self.glwe_tensor_relinearize(dst, &tmp, tsk, tmp.size() + tsk.dsize().as_usize(), &mut scratch_local);
106
107        dst.set_log_budget(res_log_budget);
108        dst.set_log_delta(res_log_delta);
109        Ok(())
110    }
111
112    fn ckks_square_tmp_bytes_default<R, T>(&self, res: &R, tsk: &T) -> usize
113    where
114        R: GLWEInfos,
115        T: GGLWEInfos,
116        Self: GLWETensoring<BE>,
117    {
118        let glwe_layout = GLWELayout {
119            n: res.n(),
120            base2k: res.base2k(),
121            k: TorusPrecision(res.max_k().as_u32()),
122            rank: res.rank(),
123        };
124
125        let lvl_0 = GLWETensor::bytes_of_from_infos(&glwe_layout);
126        let lvl_1 = self
127            .glwe_tensor_square_apply_tmp_bytes(&glwe_layout, res)
128            .max(self.glwe_tensor_relinearize_tmp_bytes(res, &glwe_layout, tsk));
129
130        lvl_0 + lvl_1
131    }
132
133    fn ckks_square_into_default<Dst, A, T>(&self, dst: &mut Dst, a: &A, tsk: &T, scratch: &mut ScratchArena<'_, BE>) -> Result<()>
134    where
135        Self: GLWETensoring<BE> + GLWECopy<BE> + ModuleCoreAlloc<OwnedBuf = BE::OwnedBuf>,
136        Dst: GLWEToBackendMut<BE> + CKKSInfos + SetCKKSInfos + GLWEInfos,
137        A: GLWEToBackendRef<BE> + CKKSInfos + GLWEInfos,
138        T: GLWETensorKeyPreparedToBackendRef<BE> + GGLWEInfos,
139    {
140        let (res_log_budget, res_log_delta, cnv_offset) = get_mul_ct_params(dst, a, a)?;
141
142        let tensor_layout = GLWELayout {
143            n: dst.n(),
144            base2k: dst.base2k(),
145            k: a.max_k(),
146            rank: dst.rank(),
147        };
148        let scratch_local = scratch.borrow();
149        let (mut tmp, mut scratch_local) = scratch_local.take_glwe_tensor_scratch(&tensor_layout);
150        self.glwe_tensor_square_apply(cnv_offset, &mut tmp, a, a.effective_k(), &mut scratch_local);
151        self.glwe_tensor_relinearize(dst, &tmp, tsk, tmp.size() + tsk.dsize().as_usize(), &mut scratch_local);
152
153        dst.set_log_budget(res_log_budget);
154        dst.set_log_delta(res_log_delta);
155        Ok(())
156    }
157
158    fn ckks_square_assign_default<Dst, T>(&self, dst: &mut Dst, tsk: &T, scratch: &mut ScratchArena<'_, BE>) -> Result<()>
159    where
160        Self: GLWETensoring<BE> + GLWECopy<BE> + ModuleCoreAlloc<OwnedBuf = BE::OwnedBuf>,
161        Dst: GLWEToBackendMut<BE> + GLWEToBackendRef<BE> + CKKSInfos + SetCKKSInfos + GLWEInfos,
162        T: GLWETensorKeyPreparedToBackendRef<BE> + GGLWEInfos,
163    {
164        let (res_log_budget, res_log_delta, cnv_offset) = get_mul_ct_params(dst, dst, dst)?;
165
166        let tensor_layout = GLWELayout {
167            n: dst.n(),
168            base2k: dst.base2k(),
169            k: dst.max_k(),
170            rank: dst.rank(),
171        };
172        let scratch_local = scratch.borrow();
173        let (mut tmp, mut scratch_local) = scratch_local.take_glwe_tensor_scratch(&tensor_layout);
174        self.glwe_tensor_square_apply(cnv_offset, &mut tmp, &*dst, dst.effective_k(), &mut scratch_local);
175        self.glwe_tensor_relinearize(dst, &tmp, tsk, tmp.size() + tsk.dsize().as_usize(), &mut scratch_local);
176
177        dst.set_log_budget(res_log_budget);
178        dst.set_log_delta(res_log_delta);
179        Ok(())
180    }
181
182    fn ckks_mul_pt_vec_tmp_bytes_default<R, A>(&self, res: &R, a: &A, b: &CKKSMeta) -> usize
183    where
184        R: GLWEInfos,
185        A: GLWEInfos,
186        Self: GLWEMulPlain<BE>,
187    {
188        let b_infos = GLWEPlaintextLayout {
189            n: res.n(),
190            base2k: res.base2k(),
191            k: b.min_k(res.base2k()),
192        };
193        self.glwe_mul_plain_tmp_bytes(res, a, &b_infos)
194    }
195
196    fn ckks_mul_pt_const_tmp_bytes_default<R, A>(&self, res: &R, a: &A, b: &CKKSMeta) -> usize
197    where
198        R: GLWEInfos,
199        A: GLWEInfos,
200        Self: GLWEMulConst<BE> + GLWERotate<BE>,
201    {
202        let b_infos = GLWEPlaintextLayout {
203            n: res.n(),
204            base2k: res.base2k(),
205            k: b.min_k(res.base2k()),
206        };
207        GLWE::<Vec<u8>>::bytes_of_from_infos(res)
208            + self
209                .glwe_mul_const_tmp_bytes(res, a, &b_infos)
210                .max(self.glwe_rotate_tmp_bytes())
211    }
212
213    fn ckks_mul_pt_vec_into_default<Dst, A, P>(
214        &self,
215        dst: &mut Dst,
216        a: &A,
217        pt: &P,
218        scratch: &mut ScratchArena<'_, BE>,
219    ) -> Result<()>
220    where
221        P: GLWEToBackendRef<BE> + LWEInfos + GLWEInfos + CKKSInfos,
222        Self: GLWECopy<BE> + GLWEMulPlain<BE> + ModuleCoreAlloc<OwnedBuf = BE::OwnedBuf> + VecZnxCopyBackend<BE>,
223        Dst: GLWEToBackendMut<BE> + CKKSInfos + SetCKKSInfos + GLWEInfos,
224        A: GLWEToBackendRef<BE> + CKKSInfos + GLWEInfos,
225    {
226        let (res_log_budget, res_log_delta, cnv_offset) = get_mul_pt_params(dst, a, pt)?;
227        self.glwe_mul_plain(cnv_offset, dst, a, a.effective_k(), pt, pt.max_k().as_usize(), scratch);
228        dst.set_log_budget(res_log_budget);
229        dst.set_log_delta(res_log_delta);
230        Ok(())
231    }
232
233    fn ckks_mul_pt_vec_assign_default<Dst, P>(&self, dst: &mut Dst, pt: &P, scratch: &mut ScratchArena<'_, BE>) -> Result<()>
234    where
235        P: GLWEToBackendRef<BE> + LWEInfos + GLWEInfos + CKKSInfos,
236        Self: GLWECopy<BE> + GLWEMulPlain<BE> + ModuleCoreAlloc<OwnedBuf = BE::OwnedBuf> + VecZnxCopyBackend<BE>,
237        Dst: GLWEToBackendMut<BE> + CKKSInfos + SetCKKSInfos + GLWEInfos,
238    {
239        let (res_log_budget, res_log_delta, cnv_offset) = get_mul_pt_params(dst, dst, pt)?;
240        let dst_effective_k = dst.effective_k();
241        self.glwe_mul_plain_assign(cnv_offset, dst, dst_effective_k, pt, pt.max_k().as_usize(), scratch);
242        dst.set_log_budget(res_log_budget);
243        dst.set_log_delta(res_log_delta);
244        Ok(())
245    }
246
247    fn ckks_mul_pt_const_into_default<Dst, A, P>(
248        &self,
249        dst: &mut Dst,
250        a: &A,
251        pt: &P,
252        pt_coeff: usize,
253        scratch: &mut ScratchArena<'_, BE>,
254    ) -> Result<()>
255    where
256        P: GLWEToBackendRef<BE> + LWEInfos + GLWEInfos + CKKSInfos,
257        Self: GLWEMulConst<BE>,
258        Dst: GLWEToBackendMut<BE> + CKKSInfos + SetCKKSInfos + GLWEInfos,
259        A: GLWEToBackendRef<BE> + CKKSInfos + GLWEInfos,
260    {
261        let (res_log_budget, res_log_delta, cnv_offset) = get_mul_pt_params(dst, a, pt)?;
262        self.glwe_mul_const(cnv_offset, dst, a, pt, pt_coeff, scratch);
263
264        dst.set_log_budget(res_log_budget);
265        dst.set_log_delta(res_log_delta);
266        Ok(())
267    }
268
269    fn ckks_mul_pt_const_assign_default<Dst, P>(
270        &self,
271        dst: &mut Dst,
272        cnst: &P,
273        cnst_coeff: usize,
274        scratch: &mut ScratchArena<'_, BE>,
275    ) -> Result<()>
276    where
277        P: GLWEToBackendRef<BE> + LWEInfos + GLWEInfos + CKKSInfos,
278        Self: GLWEMulConst<BE>,
279        Dst: GLWEToBackendMut<BE> + GLWEToBackendRef<BE> + CKKSInfos + SetCKKSInfos + GLWEInfos,
280    {
281        let (res_log_budget, res_log_delta, cnv_offset) = get_mul_pt_params(dst, dst, cnst)?;
282
283        self.glwe_mul_const_assign(cnv_offset, dst, cnst, cnst_coeff, scratch);
284
285        dst.set_log_budget(res_log_budget);
286        dst.set_log_delta(res_log_delta);
287        Ok(())
288    }
289}
290
291fn get_mul_ct_params<R, A, B>(res: &R, a: &A, b: &B) -> Result<(usize, usize, usize)>
292where
293    R: LWEInfos + CKKSInfos,
294    A: LWEInfos + CKKSInfos,
295    B: LWEInfos + CKKSInfos,
296{
297    let res_log_budget = checked_mul_ct_log_budget("mul", a.log_budget(), b.log_budget(), a.log_delta(), b.log_delta())?;
298    let res_log_delta = a.log_delta().min(b.log_delta());
299
300    let res_offset = (res_log_budget + res_log_delta).saturating_sub(res.max_k().as_usize());
301    // Addition/subtraction align to the shared, lower effective precision
302    // (`ckks_offset_binary` uses `min`). Multiplication is different: the
303    // bivariate convolution must traverse every live input limb, so the
304    // convolution offset starts after the wider operand span and then skips any
305    // extra limbs that cannot fit in `res`. This matches the already-rescaled
306    // multiplication rule documented by `CKKSMulOps` and the bivariate Torus
307    // analysis cited in the README/ePrint 2023/771.
308    let cnv_offset = a.effective_k().max(b.effective_k()) + res_offset;
309
310    Ok((
311        checked_log_budget_sub("mul", res_log_budget, res_offset)?,
312        res_log_delta,
313        cnv_offset,
314    ))
315}
316
317fn get_mul_pt_params<R, A, B>(res: &R, a: &A, b: &B) -> Result<(usize, usize, usize)>
318where
319    R: LWEInfos + CKKSInfos,
320    A: LWEInfos + CKKSInfos,
321    B: LWEInfos + CKKSInfos,
322{
323    let res_log_budget = checked_mul_pt_log_budget("mul", a.log_budget(), b.log_budget(), a.log_delta(), b.log_delta())?;
324    let res_log_delta = a.log_delta();
325    let res_offset = (res_log_budget + res_log_delta).saturating_sub(res.max_k().as_usize());
326    let cnv_offset = b.max_k().as_usize() + res_offset;
327
328    Ok((
329        checked_log_budget_sub("mul", res_log_budget, res_offset)?,
330        res_log_delta,
331        cnv_offset,
332    ))
333}