Skip to main content

poulpy_ckks/default/
sub.rs

1use anyhow::Result;
2use poulpy_core::{
3    GLWENormalize, GLWEShift, GLWESub,
4    layouts::{GLWEToBackendMut, LWEInfos},
5};
6use poulpy_hal::{
7    api::{VecZnxRshSubBackend, VecZnxRshSubCoeffIntoBackend, VecZnxRshTmpBytes},
8    layouts::{Backend, ScratchArena},
9};
10
11use crate::{
12    CKKSInfos, GLWEToBackendRef, SetCKKSInfos, checked_log_budget_sub, ckks_offset_binary, ckks_offset_unary,
13    default::add::ckks_one_pt, leveled::default::CKKSPlaintextDefault,
14};
15
16pub trait CKKSSubDefault<BE: Backend> {
17    fn ckks_sub_tmp_bytes_default(&self) -> usize
18    where
19        Self: GLWEShift<BE> + GLWENormalize<BE> + VecZnxRshTmpBytes,
20    {
21        self.glwe_shift_tmp_bytes()
22            .max(self.vec_znx_rsh_tmp_bytes())
23            .max(self.glwe_normalize_tmp_bytes())
24    }
25
26    fn ckks_sub_pt_vec_tmp_bytes_default(&self) -> usize
27    where
28        Self: GLWEShift<BE> + GLWENormalize<BE> + VecZnxRshTmpBytes,
29    {
30        self.ckks_sub_tmp_bytes_default()
31    }
32
33    fn ckks_sub_pt_const_tmp_bytes_default(&self) -> usize
34    where
35        Self: GLWEShift<BE> + GLWENormalize<BE> + VecZnxRshTmpBytes,
36    {
37        self.glwe_shift_tmp_bytes()
38            .max(self.glwe_normalize_tmp_bytes())
39            .max(self.vec_znx_rsh_tmp_bytes())
40    }
41
42    fn ckks_sub_into_default<Dst, A, B>(&self, dst: &mut Dst, a: &A, b: &B, scratch: &mut ScratchArena<'_, BE>) -> Result<()>
43    where
44        Self: GLWESub<BE> + GLWEShift<BE> + GLWENormalize<BE>,
45        Dst: GLWEToBackendMut<BE> + LWEInfos + SetCKKSInfos + CKKSInfos,
46        A: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
47        B: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
48    {
49        self.ckks_sub_into_unsafe_default(dst, a, b, scratch)?;
50        self.glwe_normalize_assign(dst, scratch);
51        Ok(())
52    }
53
54    fn ckks_sub_into_unsafe_default<Dst, A, B>(
55        &self,
56        dst: &mut Dst,
57        a: &A,
58        b: &B,
59        scratch: &mut ScratchArena<'_, BE>,
60    ) -> Result<()>
61    where
62        Self: GLWESub<BE> + GLWEShift<BE>,
63        Dst: GLWEToBackendMut<BE> + LWEInfos + SetCKKSInfos + CKKSInfos,
64        A: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
65        B: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
66    {
67        let offset = ckks_offset_binary(dst, a, b);
68
69        if offset == 0 && a.log_budget() == b.log_budget() {
70            self.glwe_sub(dst, a, b);
71        } else if a.log_budget() <= b.log_budget() {
72            self.glwe_lsh(dst, a, offset, scratch);
73            self.glwe_lsh_sub(dst, b, b.log_budget() - a.log_budget() + offset, scratch);
74        } else {
75            self.glwe_lsh(dst, a, a.log_budget() - b.log_budget() + offset, scratch);
76            self.glwe_lsh_sub(dst, b, offset, scratch);
77        }
78
79        let log_budget = checked_log_budget_sub("sub", a.log_budget().min(b.log_budget()), offset)?;
80        dst.set_log_delta(a.log_delta().min(b.log_delta()));
81        dst.set_log_budget(log_budget);
82        Ok(())
83    }
84
85    fn ckks_sub_assign_default<Dst, A>(&self, dst: &mut Dst, a: &A, scratch: &mut ScratchArena<'_, BE>) -> Result<()>
86    where
87        Self: GLWESub<BE> + GLWEShift<BE> + GLWENormalize<BE>,
88        Dst: GLWEToBackendMut<BE> + CKKSInfos + SetCKKSInfos,
89        A: GLWEToBackendRef<BE> + CKKSInfos,
90    {
91        self.ckks_sub_assign_unsafe_default(dst, a, scratch)?;
92        self.glwe_normalize_assign(dst, scratch);
93        Ok(())
94    }
95
96    fn ckks_sub_assign_unsafe_default<Dst, A>(&self, dst: &mut Dst, a: &A, scratch: &mut ScratchArena<'_, BE>) -> Result<()>
97    where
98        Self: GLWESub<BE> + GLWEShift<BE>,
99        Dst: GLWEToBackendMut<BE> + CKKSInfos + SetCKKSInfos,
100        A: GLWEToBackendRef<BE> + CKKSInfos,
101    {
102        let dst_log_budget = dst.log_budget();
103
104        if dst_log_budget < a.log_budget() {
105            self.glwe_lsh_sub(dst, a, a.log_budget() - dst_log_budget, scratch);
106        } else if dst_log_budget > a.log_budget() {
107            self.glwe_lsh_assign(dst, dst_log_budget - a.log_budget(), scratch);
108            self.glwe_sub_assign(dst, a);
109        } else {
110            self.glwe_sub_assign(dst, a);
111        }
112
113        dst.set_log_budget(dst_log_budget.min(a.log_budget()));
114        dst.set_log_delta(dst.log_delta().min(a.log_delta()));
115        Ok(())
116    }
117
118    fn ckks_sub_one_assign_default<Dst>(&self, dst: &mut Dst, scratch: &mut ScratchArena<'_, BE>) -> Result<()>
119    where
120        Self: VecZnxRshSubCoeffIntoBackend<BE> + CKKSPlaintextDefault<BE> + GLWENormalize<BE>,
121        Dst: GLWEToBackendMut<BE> + CKKSInfos + LWEInfos,
122    {
123        let one = ckks_one_pt::<BE>(dst.base2k())?;
124        self.ckks_sub_pt_const_assign_default(dst, 0, &one, 0, scratch)
125    }
126
127    fn ckks_sub_pt_vec_into_default<Dst, A, P>(
128        &self,
129        dst: &mut Dst,
130        a: &A,
131        pt: &P,
132        scratch: &mut ScratchArena<'_, BE>,
133    ) -> Result<()>
134    where
135        Self: VecZnxRshSubBackend<BE> + GLWEShift<BE> + GLWENormalize<BE> + CKKSPlaintextDefault<BE>,
136        Dst: GLWEToBackendMut<BE> + CKKSInfos + SetCKKSInfos + LWEInfos,
137        A: GLWEToBackendRef<BE> + CKKSInfos + LWEInfos,
138        P: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
139    {
140        self.ckks_sub_pt_vec_into_unsafe_default(dst, a, pt, scratch)?;
141        self.glwe_normalize_assign(dst, scratch);
142        Ok(())
143    }
144
145    fn ckks_sub_pt_vec_into_unsafe_default<Dst, A, P>(
146        &self,
147        dst: &mut Dst,
148        a: &A,
149        pt: &P,
150        scratch: &mut ScratchArena<'_, BE>,
151    ) -> Result<()>
152    where
153        Self: VecZnxRshSubBackend<BE> + GLWEShift<BE> + CKKSPlaintextDefault<BE>,
154        Dst: GLWEToBackendMut<BE> + CKKSInfos + SetCKKSInfos + LWEInfos,
155        A: GLWEToBackendRef<BE> + CKKSInfos + LWEInfos,
156        P: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
157    {
158        let offset = ckks_offset_unary(dst, a);
159        self.glwe_lsh(dst, a, offset, scratch);
160        dst.set_meta(a.meta());
161        dst.set_log_budget(checked_log_budget_sub("sub_pt_vec", a.log_budget(), offset)?);
162        self.ckks_sub_pt_vec_assign_unsafe_default(dst, pt, scratch)?;
163        Ok(())
164    }
165
166    fn ckks_sub_pt_vec_assign_default<Dst, P>(&self, dst: &mut Dst, pt: &P, scratch: &mut ScratchArena<'_, BE>) -> Result<()>
167    where
168        Self: VecZnxRshSubBackend<BE> + GLWENormalize<BE> + CKKSPlaintextDefault<BE>,
169        Dst: GLWEToBackendMut<BE> + CKKSInfos + LWEInfos,
170        P: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
171    {
172        self.ckks_sub_pt_vec_assign_unsafe_default(dst, pt, scratch)?;
173        self.glwe_normalize_assign(dst, scratch);
174        Ok(())
175    }
176
177    fn ckks_sub_pt_vec_assign_unsafe_default<Dst, P>(
178        &self,
179        dst: &mut Dst,
180        pt: &P,
181        scratch: &mut ScratchArena<'_, BE>,
182    ) -> Result<()>
183    where
184        Self: VecZnxRshSubBackend<BE> + CKKSPlaintextDefault<BE>,
185        Dst: GLWEToBackendMut<BE> + CKKSInfos + LWEInfos,
186        P: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
187    {
188        CKKSPlaintextDefault::ckks_sub_pt_vec_into_default(self, dst, pt, scratch)?;
189        Ok(())
190    }
191
192    fn ckks_sub_pt_const_into_default<Dst, A, P>(
193        &self,
194        dst: &mut Dst,
195        a: &A,
196        dst_coeff: usize,
197        cst: &P,
198        const_coeff: usize,
199        scratch: &mut ScratchArena<'_, BE>,
200    ) -> Result<()>
201    where
202        Self: GLWEShift<BE> + GLWENormalize<BE> + VecZnxRshSubCoeffIntoBackend<BE> + CKKSPlaintextDefault<BE>,
203        Dst: GLWEToBackendMut<BE> + CKKSInfos + SetCKKSInfos + LWEInfos,
204        A: GLWEToBackendRef<BE> + CKKSInfos + LWEInfos,
205        P: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
206    {
207        self.ckks_sub_pt_const_into_unsafe_default(dst, a, dst_coeff, cst, const_coeff, scratch)?;
208        self.glwe_normalize_assign(dst, scratch);
209        Ok(())
210    }
211
212    fn ckks_sub_pt_const_into_unsafe_default<Dst, A, P>(
213        &self,
214        dst: &mut Dst,
215        a: &A,
216        dst_coeff: usize,
217        cst: &P,
218        const_coeff: usize,
219        scratch: &mut ScratchArena<'_, BE>,
220    ) -> Result<()>
221    where
222        Self: GLWEShift<BE> + VecZnxRshSubCoeffIntoBackend<BE> + CKKSPlaintextDefault<BE>,
223        Dst: GLWEToBackendMut<BE> + CKKSInfos + SetCKKSInfos + LWEInfos,
224        A: GLWEToBackendRef<BE> + CKKSInfos + LWEInfos,
225        P: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
226    {
227        let offset = ckks_offset_unary(dst, a);
228        self.glwe_lsh(dst, a, offset, scratch);
229        dst.set_meta(a.meta());
230        dst.set_log_budget(checked_log_budget_sub("sub_pt_const", a.log_budget(), offset)?);
231        self.ckks_sub_pt_const_assign_unsafe_default(dst, dst_coeff, cst, const_coeff, scratch)
232    }
233
234    fn ckks_sub_pt_const_assign_default<Dst, P>(
235        &self,
236        dst: &mut Dst,
237        dst_coeff: usize,
238        cst: &P,
239        const_coeff: usize,
240        scratch: &mut ScratchArena<'_, BE>,
241    ) -> Result<()>
242    where
243        Self: VecZnxRshSubCoeffIntoBackend<BE> + CKKSPlaintextDefault<BE> + GLWENormalize<BE>,
244        Dst: GLWEToBackendMut<BE> + CKKSInfos + LWEInfos,
245        P: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
246    {
247        self.ckks_sub_pt_const_assign_unsafe_default(dst, dst_coeff, cst, const_coeff, scratch)?;
248        self.glwe_normalize_assign(dst, scratch);
249        Ok(())
250    }
251
252    fn ckks_sub_pt_const_assign_unsafe_default<Dst, P>(
253        &self,
254        dst: &mut Dst,
255        dst_coeff: usize,
256        cst: &P,
257        const_coeff: usize,
258        scratch: &mut ScratchArena<'_, BE>,
259    ) -> Result<()>
260    where
261        Self: VecZnxRshSubCoeffIntoBackend<BE> + CKKSPlaintextDefault<BE>,
262        Dst: GLWEToBackendMut<BE> + CKKSInfos + LWEInfos,
263        P: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
264    {
265        CKKSPlaintextDefault::ckks_sub_pt_const_into_default(self, dst, dst_coeff, cst, const_coeff, scratch)?;
266        Ok(())
267    }
268}