1use anyhow::Result;
2use poulpy_core::{
3 GLWENormalize, GLWEShift, GLWESub,
4 layouts::{GLWEToBackendMut, LWEInfos},
5};
6use poulpy_hal::{
7 api::{VecZnxRshSubBackend, VecZnxRshSubCoeffIntoBackend, VecZnxRshTmpBytes},
8 layouts::{Backend, ScratchArena},
9};
10
11use crate::{
12 CKKSInfos, GLWEToBackendRef, SetCKKSInfos, checked_log_budget_sub, ckks_offset_binary, ckks_offset_unary,
13 default::add::ckks_one_pt, leveled::default::CKKSPlaintextDefault,
14};
15
16pub trait CKKSSubDefault<BE: Backend> {
17 fn ckks_sub_tmp_bytes_default(&self) -> usize
18 where
19 Self: GLWEShift<BE> + GLWENormalize<BE> + VecZnxRshTmpBytes,
20 {
21 self.glwe_shift_tmp_bytes()
22 .max(self.vec_znx_rsh_tmp_bytes())
23 .max(self.glwe_normalize_tmp_bytes())
24 }
25
26 fn ckks_sub_pt_vec_tmp_bytes_default(&self) -> usize
27 where
28 Self: GLWEShift<BE> + GLWENormalize<BE> + VecZnxRshTmpBytes,
29 {
30 self.ckks_sub_tmp_bytes_default()
31 }
32
33 fn ckks_sub_pt_const_tmp_bytes_default(&self) -> usize
34 where
35 Self: GLWEShift<BE> + GLWENormalize<BE> + VecZnxRshTmpBytes,
36 {
37 self.glwe_shift_tmp_bytes()
38 .max(self.glwe_normalize_tmp_bytes())
39 .max(self.vec_znx_rsh_tmp_bytes())
40 }
41
42 fn ckks_sub_into_default<Dst, A, B>(&self, dst: &mut Dst, a: &A, b: &B, scratch: &mut ScratchArena<'_, BE>) -> Result<()>
43 where
44 Self: GLWESub<BE> + GLWEShift<BE> + GLWENormalize<BE>,
45 Dst: GLWEToBackendMut<BE> + LWEInfos + SetCKKSInfos + CKKSInfos,
46 A: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
47 B: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
48 {
49 self.ckks_sub_into_unsafe_default(dst, a, b, scratch)?;
50 self.glwe_normalize_assign(dst, scratch);
51 Ok(())
52 }
53
54 fn ckks_sub_into_unsafe_default<Dst, A, B>(
55 &self,
56 dst: &mut Dst,
57 a: &A,
58 b: &B,
59 scratch: &mut ScratchArena<'_, BE>,
60 ) -> Result<()>
61 where
62 Self: GLWESub<BE> + GLWEShift<BE>,
63 Dst: GLWEToBackendMut<BE> + LWEInfos + SetCKKSInfos + CKKSInfos,
64 A: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
65 B: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
66 {
67 let offset = ckks_offset_binary(dst, a, b);
68
69 if offset == 0 && a.log_budget() == b.log_budget() {
70 self.glwe_sub(dst, a, b);
71 } else if a.log_budget() <= b.log_budget() {
72 self.glwe_lsh(dst, a, offset, scratch);
73 self.glwe_lsh_sub(dst, b, b.log_budget() - a.log_budget() + offset, scratch);
74 } else {
75 self.glwe_lsh(dst, a, a.log_budget() - b.log_budget() + offset, scratch);
76 self.glwe_lsh_sub(dst, b, offset, scratch);
77 }
78
79 let log_budget = checked_log_budget_sub("sub", a.log_budget().min(b.log_budget()), offset)?;
80 dst.set_log_delta(a.log_delta().min(b.log_delta()));
81 dst.set_log_budget(log_budget);
82 Ok(())
83 }
84
85 fn ckks_sub_assign_default<Dst, A>(&self, dst: &mut Dst, a: &A, scratch: &mut ScratchArena<'_, BE>) -> Result<()>
86 where
87 Self: GLWESub<BE> + GLWEShift<BE> + GLWENormalize<BE>,
88 Dst: GLWEToBackendMut<BE> + CKKSInfos + SetCKKSInfos,
89 A: GLWEToBackendRef<BE> + CKKSInfos,
90 {
91 self.ckks_sub_assign_unsafe_default(dst, a, scratch)?;
92 self.glwe_normalize_assign(dst, scratch);
93 Ok(())
94 }
95
96 fn ckks_sub_assign_unsafe_default<Dst, A>(&self, dst: &mut Dst, a: &A, scratch: &mut ScratchArena<'_, BE>) -> Result<()>
97 where
98 Self: GLWESub<BE> + GLWEShift<BE>,
99 Dst: GLWEToBackendMut<BE> + CKKSInfos + SetCKKSInfos,
100 A: GLWEToBackendRef<BE> + CKKSInfos,
101 {
102 let dst_log_budget = dst.log_budget();
103
104 if dst_log_budget < a.log_budget() {
105 self.glwe_lsh_sub(dst, a, a.log_budget() - dst_log_budget, scratch);
106 } else if dst_log_budget > a.log_budget() {
107 self.glwe_lsh_assign(dst, dst_log_budget - a.log_budget(), scratch);
108 self.glwe_sub_assign(dst, a);
109 } else {
110 self.glwe_sub_assign(dst, a);
111 }
112
113 dst.set_log_budget(dst_log_budget.min(a.log_budget()));
114 dst.set_log_delta(dst.log_delta().min(a.log_delta()));
115 Ok(())
116 }
117
118 fn ckks_sub_one_assign_default<Dst>(&self, dst: &mut Dst, scratch: &mut ScratchArena<'_, BE>) -> Result<()>
119 where
120 Self: VecZnxRshSubCoeffIntoBackend<BE> + CKKSPlaintextDefault<BE> + GLWENormalize<BE>,
121 Dst: GLWEToBackendMut<BE> + CKKSInfos + LWEInfos,
122 {
123 let one = ckks_one_pt::<BE>(dst.base2k())?;
124 self.ckks_sub_pt_const_assign_default(dst, 0, &one, 0, scratch)
125 }
126
127 fn ckks_sub_pt_vec_into_default<Dst, A, P>(
128 &self,
129 dst: &mut Dst,
130 a: &A,
131 pt: &P,
132 scratch: &mut ScratchArena<'_, BE>,
133 ) -> Result<()>
134 where
135 Self: VecZnxRshSubBackend<BE> + GLWEShift<BE> + GLWENormalize<BE> + CKKSPlaintextDefault<BE>,
136 Dst: GLWEToBackendMut<BE> + CKKSInfos + SetCKKSInfos + LWEInfos,
137 A: GLWEToBackendRef<BE> + CKKSInfos + LWEInfos,
138 P: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
139 {
140 self.ckks_sub_pt_vec_into_unsafe_default(dst, a, pt, scratch)?;
141 self.glwe_normalize_assign(dst, scratch);
142 Ok(())
143 }
144
145 fn ckks_sub_pt_vec_into_unsafe_default<Dst, A, P>(
146 &self,
147 dst: &mut Dst,
148 a: &A,
149 pt: &P,
150 scratch: &mut ScratchArena<'_, BE>,
151 ) -> Result<()>
152 where
153 Self: VecZnxRshSubBackend<BE> + GLWEShift<BE> + CKKSPlaintextDefault<BE>,
154 Dst: GLWEToBackendMut<BE> + CKKSInfos + SetCKKSInfos + LWEInfos,
155 A: GLWEToBackendRef<BE> + CKKSInfos + LWEInfos,
156 P: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
157 {
158 let offset = ckks_offset_unary(dst, a);
159 self.glwe_lsh(dst, a, offset, scratch);
160 dst.set_meta(a.meta());
161 dst.set_log_budget(checked_log_budget_sub("sub_pt_vec", a.log_budget(), offset)?);
162 self.ckks_sub_pt_vec_assign_unsafe_default(dst, pt, scratch)?;
163 Ok(())
164 }
165
166 fn ckks_sub_pt_vec_assign_default<Dst, P>(&self, dst: &mut Dst, pt: &P, scratch: &mut ScratchArena<'_, BE>) -> Result<()>
167 where
168 Self: VecZnxRshSubBackend<BE> + GLWENormalize<BE> + CKKSPlaintextDefault<BE>,
169 Dst: GLWEToBackendMut<BE> + CKKSInfos + LWEInfos,
170 P: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
171 {
172 self.ckks_sub_pt_vec_assign_unsafe_default(dst, pt, scratch)?;
173 self.glwe_normalize_assign(dst, scratch);
174 Ok(())
175 }
176
177 fn ckks_sub_pt_vec_assign_unsafe_default<Dst, P>(
178 &self,
179 dst: &mut Dst,
180 pt: &P,
181 scratch: &mut ScratchArena<'_, BE>,
182 ) -> Result<()>
183 where
184 Self: VecZnxRshSubBackend<BE> + CKKSPlaintextDefault<BE>,
185 Dst: GLWEToBackendMut<BE> + CKKSInfos + LWEInfos,
186 P: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
187 {
188 CKKSPlaintextDefault::ckks_sub_pt_vec_into_default(self, dst, pt, scratch)?;
189 Ok(())
190 }
191
192 fn ckks_sub_pt_const_into_default<Dst, A, P>(
193 &self,
194 dst: &mut Dst,
195 a: &A,
196 dst_coeff: usize,
197 cst: &P,
198 const_coeff: usize,
199 scratch: &mut ScratchArena<'_, BE>,
200 ) -> Result<()>
201 where
202 Self: GLWEShift<BE> + GLWENormalize<BE> + VecZnxRshSubCoeffIntoBackend<BE> + CKKSPlaintextDefault<BE>,
203 Dst: GLWEToBackendMut<BE> + CKKSInfos + SetCKKSInfos + LWEInfos,
204 A: GLWEToBackendRef<BE> + CKKSInfos + LWEInfos,
205 P: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
206 {
207 self.ckks_sub_pt_const_into_unsafe_default(dst, a, dst_coeff, cst, const_coeff, scratch)?;
208 self.glwe_normalize_assign(dst, scratch);
209 Ok(())
210 }
211
212 fn ckks_sub_pt_const_into_unsafe_default<Dst, A, P>(
213 &self,
214 dst: &mut Dst,
215 a: &A,
216 dst_coeff: usize,
217 cst: &P,
218 const_coeff: usize,
219 scratch: &mut ScratchArena<'_, BE>,
220 ) -> Result<()>
221 where
222 Self: GLWEShift<BE> + VecZnxRshSubCoeffIntoBackend<BE> + CKKSPlaintextDefault<BE>,
223 Dst: GLWEToBackendMut<BE> + CKKSInfos + SetCKKSInfos + LWEInfos,
224 A: GLWEToBackendRef<BE> + CKKSInfos + LWEInfos,
225 P: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
226 {
227 let offset = ckks_offset_unary(dst, a);
228 self.glwe_lsh(dst, a, offset, scratch);
229 dst.set_meta(a.meta());
230 dst.set_log_budget(checked_log_budget_sub("sub_pt_const", a.log_budget(), offset)?);
231 self.ckks_sub_pt_const_assign_unsafe_default(dst, dst_coeff, cst, const_coeff, scratch)
232 }
233
234 fn ckks_sub_pt_const_assign_default<Dst, P>(
235 &self,
236 dst: &mut Dst,
237 dst_coeff: usize,
238 cst: &P,
239 const_coeff: usize,
240 scratch: &mut ScratchArena<'_, BE>,
241 ) -> Result<()>
242 where
243 Self: VecZnxRshSubCoeffIntoBackend<BE> + CKKSPlaintextDefault<BE> + GLWENormalize<BE>,
244 Dst: GLWEToBackendMut<BE> + CKKSInfos + LWEInfos,
245 P: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
246 {
247 self.ckks_sub_pt_const_assign_unsafe_default(dst, dst_coeff, cst, const_coeff, scratch)?;
248 self.glwe_normalize_assign(dst, scratch);
249 Ok(())
250 }
251
252 fn ckks_sub_pt_const_assign_unsafe_default<Dst, P>(
253 &self,
254 dst: &mut Dst,
255 dst_coeff: usize,
256 cst: &P,
257 const_coeff: usize,
258 scratch: &mut ScratchArena<'_, BE>,
259 ) -> Result<()>
260 where
261 Self: VecZnxRshSubCoeffIntoBackend<BE> + CKKSPlaintextDefault<BE>,
262 Dst: GLWEToBackendMut<BE> + CKKSInfos + LWEInfos,
263 P: GLWEToBackendRef<BE> + LWEInfos + CKKSInfos,
264 {
265 CKKSPlaintextDefault::ckks_sub_pt_const_into_default(self, dst, dst_coeff, cst, const_coeff, scratch)?;
266 Ok(())
267 }
268}