Skip to main content

poulpy_ckks/default/
add.rs

1use anyhow::Result;
2use poulpy_core::{
3    GLWEAdd, GLWENormalize, GLWEShift,
4    layouts::{Base2K, GLWEPlaintext, GLWEToBackendMut, GLWEToBackendRef, LWEInfos},
5};
6use poulpy_hal::{
7    api::{VecZnxRshAddCoeffIntoBackend, VecZnxRshAddIntoBackend, VecZnxRshTmpBytes},
8    layouts::{Backend, ScratchArena, VecZnx},
9};
10
11use crate::{
12    CKKSInfos, CKKSMeta, SetCKKSInfos, checked_log_budget_sub, ckks_offset_binary, ckks_offset_unary,
13    default::CKKSPlaintextDefault,
14    layouts::{CKKSPlaintext, CKKSPlaintextVecHostCodec},
15};
16
17pub(crate) fn ckks_one_pt<BE>(base2k: Base2K) -> Result<CKKSPlaintext<BE::OwnedBuf>>
18where
19    BE: Backend,
20{
21    let meta = CKKSMeta {
22        log_delta: 1,
23        log_budget: 0,
24    };
25
26    let mut host_pt = CKKSPlaintext::from_inner(
27        GLWEPlaintext::alloc_with_meta(1usize.into(), base2k, meta.min_k(base2k)),
28        meta,
29    );
30    host_pt.encode_host_floats(&[1.0f64])?;
31
32    let shape = host_pt.inner.data.shape();
33    let backend_inner = GLWEPlaintext {
34        data: VecZnx::from_data_with_max_size(
35            BE::from_host_bytes(host_pt.inner.data.data.as_ref()),
36            shape.n(),
37            shape.cols(),
38            shape.size(),
39            shape.max_size(),
40        ),
41        base2k,
42    };
43    Ok(CKKSPlaintext::from_inner(backend_inner, meta))
44}
45
46pub trait CKKSAddDefault<BE: Backend> {
47    fn ckks_add_tmp_bytes_default(&self) -> usize
48    where
49        Self: GLWEShift<BE> + GLWENormalize<BE>,
50    {
51        self.glwe_shift_tmp_bytes().max(self.glwe_normalize_tmp_bytes())
52    }
53
54    fn ckks_add_pt_vec_tmp_bytes_default(&self) -> usize
55    where
56        Self: GLWEShift<BE> + GLWENormalize<BE> + VecZnxRshTmpBytes,
57    {
58        self.glwe_shift_tmp_bytes()
59            .max(self.vec_znx_rsh_tmp_bytes())
60            .max(self.glwe_normalize_tmp_bytes())
61    }
62
63    fn ckks_add_pt_const_tmp_bytes_default(&self) -> usize
64    where
65        Self: GLWEShift<BE> + GLWENormalize<BE> + VecZnxRshTmpBytes,
66    {
67        self.glwe_shift_tmp_bytes()
68            .max(self.glwe_normalize_tmp_bytes())
69            .max(self.vec_znx_rsh_tmp_bytes())
70    }
71
72    fn ckks_add_into_default<Dst, A, B>(&self, dst: &mut Dst, a: &A, b: &B, scratch: &mut ScratchArena<'_, BE>) -> Result<()>
73    where
74        Self: GLWEAdd<BE> + GLWEShift<BE> + GLWENormalize<BE>,
75        Dst: GLWEToBackendMut<BE> + LWEInfos + SetCKKSInfos + CKKSInfos,
76        A: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
77        B: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
78    {
79        self.ckks_add_into_unsafe_default(dst, a, b, scratch)?;
80        self.glwe_normalize_assign(dst, scratch);
81        Ok(())
82    }
83
84    fn ckks_add_into_unsafe_default<Dst, A, B>(
85        &self,
86        dst: &mut Dst,
87        a: &A,
88        b: &B,
89        scratch: &mut ScratchArena<'_, BE>,
90    ) -> Result<()>
91    where
92        Self: GLWEAdd<BE> + GLWEShift<BE>,
93        Dst: GLWEToBackendMut<BE> + LWEInfos + SetCKKSInfos + CKKSInfos,
94        A: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
95        B: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
96    {
97        let offset = ckks_offset_binary(dst, a, b);
98
99        if offset == 0 && a.log_budget() == b.log_budget() {
100            self.glwe_add_into(dst, a, b);
101        } else if a.log_budget() <= b.log_budget() {
102            self.glwe_lsh(dst, a, offset, scratch);
103            self.glwe_lsh_add(dst, b, b.log_budget() - a.log_budget() + offset, scratch);
104        } else {
105            self.glwe_lsh(dst, b, offset, scratch);
106            self.glwe_lsh_add(dst, a, a.log_budget() - b.log_budget() + offset, scratch);
107        }
108
109        let log_budget = checked_log_budget_sub("add", a.log_budget().min(b.log_budget()), offset)?;
110        dst.set_log_delta(a.log_delta().min(b.log_delta()));
111        dst.set_log_budget(log_budget);
112        Ok(())
113    }
114
115    fn ckks_add_assign_default<Dst, A>(&self, dst: &mut Dst, a: &A, scratch: &mut ScratchArena<'_, BE>) -> Result<()>
116    where
117        Self: GLWEAdd<BE> + GLWEShift<BE> + GLWENormalize<BE>,
118        Dst: GLWEToBackendMut<BE> + CKKSInfos + SetCKKSInfos,
119        A: GLWEToBackendRef<BE> + CKKSInfos,
120    {
121        self.ckks_add_assign_unsafe_default(dst, a, scratch)?;
122        self.glwe_normalize_assign(dst, scratch);
123        Ok(())
124    }
125
126    fn ckks_add_assign_unsafe_default<Dst, A>(&self, dst: &mut Dst, a: &A, scratch: &mut ScratchArena<'_, BE>) -> Result<()>
127    where
128        Self: GLWEAdd<BE> + GLWEShift<BE>,
129        Dst: GLWEToBackendMut<BE> + CKKSInfos + SetCKKSInfos,
130        A: GLWEToBackendRef<BE> + CKKSInfos,
131    {
132        let dst_log_budget = dst.log_budget();
133
134        if dst_log_budget < a.log_budget() {
135            self.glwe_lsh_add(dst, a, a.log_budget() - dst_log_budget, scratch);
136        } else if dst_log_budget > a.log_budget() {
137            self.glwe_lsh_assign(dst, dst_log_budget - a.log_budget(), scratch);
138            self.glwe_add_assign(dst, a);
139        } else {
140            self.glwe_add_assign(dst, a);
141        }
142
143        dst.set_log_budget(dst_log_budget.min(a.log_budget()));
144        dst.set_log_delta(dst.log_delta().min(a.log_delta()));
145        Ok(())
146    }
147
148    fn ckks_add_one_assign_default<Dst>(&self, dst: &mut Dst, scratch: &mut ScratchArena<'_, BE>) -> Result<()>
149    where
150        Self: GLWENormalize<BE> + VecZnxRshAddCoeffIntoBackend<BE> + CKKSPlaintextDefault<BE>,
151        Dst: GLWEToBackendMut<BE> + LWEInfos + CKKSInfos,
152    {
153        let one = ckks_one_pt::<BE>(dst.base2k())?;
154        self.ckks_add_pt_const_assign_default(dst, 0, &one, 0, scratch)
155    }
156
157    fn ckks_add_pt_vec_into_default<Dst, A, P>(
158        &self,
159        dst: &mut Dst,
160        a: &A,
161        pt: &P,
162        scratch: &mut ScratchArena<'_, BE>,
163    ) -> Result<()>
164    where
165        Self: VecZnxRshAddIntoBackend<BE> + GLWEShift<BE> + GLWENormalize<BE> + CKKSPlaintextDefault<BE>,
166        Dst: GLWEToBackendMut<BE> + LWEInfos + CKKSInfos + SetCKKSInfos,
167        A: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
168        P: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
169    {
170        self.ckks_add_pt_vec_into_unsafe_default(dst, a, pt, scratch)?;
171        self.glwe_normalize_assign(dst, scratch);
172        Ok(())
173    }
174
175    fn ckks_add_pt_vec_into_unsafe_default<Dst, A, P>(
176        &self,
177        dst: &mut Dst,
178        a: &A,
179        pt: &P,
180        scratch: &mut ScratchArena<'_, BE>,
181    ) -> Result<()>
182    where
183        Self: VecZnxRshAddIntoBackend<BE> + GLWEShift<BE> + CKKSPlaintextDefault<BE>,
184        Dst: GLWEToBackendMut<BE> + LWEInfos + CKKSInfos + SetCKKSInfos,
185        A: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
186        P: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
187    {
188        let offset = ckks_offset_unary(dst, a);
189        self.glwe_lsh(dst, a, offset, scratch);
190        dst.set_meta(a.meta());
191        dst.set_log_budget(checked_log_budget_sub("add_pt_vec", a.log_budget(), offset)?);
192        self.ckks_add_pt_vec_assign_unsafe_default(dst, pt, scratch)?;
193        Ok(())
194    }
195
196    fn ckks_add_pt_vec_assign_default<Dst, P>(&self, dst: &mut Dst, pt: &P, scratch: &mut ScratchArena<'_, BE>) -> Result<()>
197    where
198        Self: VecZnxRshAddIntoBackend<BE> + GLWENormalize<BE> + CKKSPlaintextDefault<BE>,
199        Dst: GLWEToBackendMut<BE> + LWEInfos + CKKSInfos,
200        P: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
201    {
202        self.ckks_add_pt_vec_assign_unsafe_default(dst, pt, scratch)?;
203        self.glwe_normalize_assign(dst, scratch);
204        Ok(())
205    }
206
207    fn ckks_add_pt_vec_assign_unsafe_default<Dst, P>(
208        &self,
209        dst: &mut Dst,
210        pt: &P,
211        scratch: &mut ScratchArena<'_, BE>,
212    ) -> Result<()>
213    where
214        Self: VecZnxRshAddIntoBackend<BE> + CKKSPlaintextDefault<BE>,
215        Dst: GLWEToBackendMut<BE> + LWEInfos + CKKSInfos,
216        P: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
217    {
218        CKKSPlaintextDefault::ckks_add_pt_vec_into_default(self, dst, pt, scratch)?;
219        Ok(())
220    }
221
222    fn ckks_add_pt_const_into_default<Dst, A, P>(
223        &self,
224        dst: &mut Dst,
225        a: &A,
226        dst_coeff: usize,
227        cst: &P,
228        const_coeff: usize,
229        scratch: &mut ScratchArena<'_, BE>,
230    ) -> Result<()>
231    where
232        Self: GLWEShift<BE> + GLWENormalize<BE> + VecZnxRshAddCoeffIntoBackend<BE> + CKKSPlaintextDefault<BE>,
233        Dst: GLWEToBackendMut<BE> + LWEInfos + CKKSInfos + SetCKKSInfos,
234        A: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
235        P: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
236    {
237        self.ckks_add_pt_const_into_unsafe_default(dst, a, dst_coeff, cst, const_coeff, scratch)?;
238        self.glwe_normalize_assign(dst, scratch);
239        Ok(())
240    }
241
242    fn ckks_add_pt_const_into_unsafe_default<Dst, A, P>(
243        &self,
244        dst: &mut Dst,
245        a: &A,
246        dst_coeff: usize,
247        cst: &P,
248        const_coeff: usize,
249        scratch: &mut ScratchArena<'_, BE>,
250    ) -> Result<()>
251    where
252        Self: GLWEShift<BE> + VecZnxRshAddCoeffIntoBackend<BE> + CKKSPlaintextDefault<BE>,
253        Dst: GLWEToBackendMut<BE> + LWEInfos + CKKSInfos + SetCKKSInfos,
254        A: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
255        P: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
256    {
257        let offset = ckks_offset_unary(dst, a);
258        self.glwe_lsh(dst, a, offset, scratch);
259        dst.set_meta(a.meta());
260        dst.set_log_budget(checked_log_budget_sub("add_const", a.log_budget(), offset)?);
261        self.ckks_add_pt_const_assign_unsafe_default(dst, dst_coeff, cst, const_coeff, scratch)
262    }
263
264    fn ckks_add_pt_const_assign_default<Dst, P>(
265        &self,
266        dst: &mut Dst,
267        dst_coeff: usize,
268        cst: &P,
269        const_coeff: usize,
270        scratch: &mut ScratchArena<'_, BE>,
271    ) -> Result<()>
272    where
273        Self: GLWENormalize<BE> + VecZnxRshAddCoeffIntoBackend<BE> + CKKSPlaintextDefault<BE>,
274        Dst: GLWEToBackendMut<BE> + LWEInfos + CKKSInfos,
275        P: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
276    {
277        self.ckks_add_pt_const_assign_unsafe_default(dst, dst_coeff, cst, const_coeff, scratch)?;
278        self.glwe_normalize_assign(dst, scratch);
279        Ok(())
280    }
281
282    fn ckks_add_pt_const_assign_unsafe_default<Dst, P>(
283        &self,
284        dst: &mut Dst,
285        dst_coeff: usize,
286        cst: &P,
287        const_coeff: usize,
288        scratch: &mut ScratchArena<'_, BE>,
289    ) -> Result<()>
290    where
291        Self: VecZnxRshAddCoeffIntoBackend<BE> + CKKSPlaintextDefault<BE>,
292        Dst: GLWEToBackendMut<BE> + LWEInfos + CKKSInfos,
293        P: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
294    {
295        CKKSPlaintextDefault::ckks_add_pt_const_into_default(self, dst, dst_coeff, cst, const_coeff, scratch)?;
296        Ok(())
297    }
298}