poulpy_core/layouts/prepared/
ggsw.rs1use poulpy_hal::{
2 api::{VmpPMatAlloc, VmpPMatBytesOf, VmpPrepare, VmpPrepareTmpBytes, VmpZero},
3 layouts::{Backend, Data, HostDataRef, Module, ScratchArena, VmpPMat, VmpPMatToBackendMut, VmpPMatToBackendRef},
4};
5
6use crate::layouts::{
7 Base2K, Degree, Dnum, Dsize, GGSWInfos, GGSWToBackendRef, GLWEInfos, GetDegree, LWEInfos, Rank, TorusPrecision,
8};
9
10#[derive(PartialEq, Eq)]
16pub struct GGSWPrepared<D: Data, B: Backend> {
17 pub(crate) data: VmpPMat<D, B>,
18 pub(crate) base2k: Base2K,
19 pub(crate) dsize: Dsize,
20}
21
/// [`GGSWPrepared`] whose storage is the backend's `BufRef<'a>` buffer type (read-only view).
pub type GGSWPreparedBackendRef<'a, B> = GGSWPrepared<<B as Backend>::BufRef<'a>, B>;
/// [`GGSWPrepared`] whose storage is the backend's `BufMut<'a>` buffer type (mutable view).
pub type GGSWPreparedBackendMut<'a, B> = GGSWPrepared<<B as Backend>::BufMut<'a>, B>;
24
25impl<D: Data, B: Backend> LWEInfos for GGSWPrepared<D, B> {
26 fn n(&self) -> Degree {
27 Degree(self.data.n() as u32)
28 }
29
30 fn base2k(&self) -> Base2K {
31 self.base2k
32 }
33
34 fn size(&self) -> usize {
35 self.data.size()
36 }
37}
38
39impl<D: Data, B: Backend> LWEInfos for &GGSWPrepared<D, B> {
40 fn n(&self) -> Degree {
41 Degree(self.data.n() as u32)
42 }
43
44 fn base2k(&self) -> Base2K {
45 self.base2k
46 }
47
48 fn size(&self) -> usize {
49 self.data.size()
50 }
51}
52
53impl<D: Data, B: Backend> GLWEInfos for GGSWPrepared<D, B> {
54 fn rank(&self) -> Rank {
55 Rank(self.data.cols_out() as u32 - 1)
56 }
57}
58
59impl<D: Data, B: Backend> GLWEInfos for &GGSWPrepared<D, B> {
60 fn rank(&self) -> Rank {
61 Rank(self.data.cols_out() as u32 - 1)
62 }
63}
64
65impl<D: Data, B: Backend> GGSWInfos for GGSWPrepared<D, B> {
66 fn dsize(&self) -> Dsize {
67 self.dsize
68 }
69
70 fn dnum(&self) -> Dnum {
71 Dnum(self.data.rows() as u32)
72 }
73}
74
75impl<D: Data, B: Backend> GGSWInfos for &GGSWPrepared<D, B> {
76 fn dsize(&self) -> Dsize {
77 self.dsize
78 }
79
80 fn dnum(&self) -> Dnum {
81 Dnum(self.data.rows() as u32)
82 }
83}
84
85pub trait GGSWPreparedFactory<B: Backend>
87where
88 Self: GetDegree + VmpPMatAlloc<B> + VmpPMatBytesOf + VmpPrepareTmpBytes + VmpPrepare<B> + VmpZero<B>,
89{
90 fn ggsw_prepared_alloc(
92 &self,
93 base2k: Base2K,
94 k: TorusPrecision,
95 dnum: Dnum,
96 dsize: Dsize,
97 rank: Rank,
98 ) -> GGSWPrepared<B::OwnedBuf, B> {
99 let size: usize = k.0.div_ceil(base2k.0) as usize;
100 debug_assert!(
101 size as u32 > dsize.0,
102 "invalid ggsw: ceil(k/base2k): {size} <= dsize: {}",
103 dsize.0
104 );
105
106 assert!(
107 dnum.0 * dsize.0 <= size as u32,
108 "invalid ggsw: dnum: {} * dsize:{} > ceil(k/base2k): {size}",
109 dnum.0,
110 dsize.0,
111 );
112
113 GGSWPrepared {
114 data: self.vmp_pmat_alloc(
115 dnum.into(),
116 (rank + 1).into(),
117 (rank + 1).into(),
118 k.0.div_ceil(base2k.0) as usize,
119 ),
120 base2k,
121 dsize,
122 }
123 }
124
125 fn ggsw_prepared_alloc_from_infos<A>(&self, infos: &A) -> GGSWPrepared<B::OwnedBuf, B>
126 where
127 A: GGSWInfos,
128 {
129 assert_eq!(self.ring_degree(), infos.n());
130 self.ggsw_prepared_alloc(infos.base2k(), infos.max_k(), infos.dnum(), infos.dsize(), infos.rank())
131 }
132
133 fn ggsw_prepared_bytes_of(&self, base2k: Base2K, k: TorusPrecision, dnum: Dnum, dsize: Dsize, rank: Rank) -> usize {
134 let size: usize = k.0.div_ceil(base2k.0) as usize;
135 debug_assert!(
136 size as u32 > dsize.0,
137 "invalid ggsw: ceil(k/base2k): {size} <= dsize: {}",
138 dsize.0
139 );
140
141 assert!(
142 dnum.0 * dsize.0 <= size as u32,
143 "invalid ggsw: dnum: {} * dsize:{} > ceil(k/base2k): {size}",
144 dnum.0,
145 dsize.0,
146 );
147
148 self.bytes_of_vmp_pmat(dnum.into(), (rank + 1).into(), (rank + 1).into(), size)
149 }
150
151 fn ggsw_prepared_bytes_of_from_infos<A>(&self, infos: &A) -> usize
152 where
153 A: GGSWInfos,
154 {
155 assert_eq!(self.ring_degree(), infos.n());
156 self.ggsw_prepared_bytes_of(infos.base2k(), infos.max_k(), infos.dnum(), infos.dsize(), infos.rank())
157 }
158
159 fn ggsw_prepare_tmp_bytes<A>(&self, infos: &A) -> usize
160 where
161 A: GGSWInfos,
162 {
163 assert_eq!(self.ring_degree(), infos.n());
164 let lvl_0: usize = self.vmp_prepare_tmp_bytes(
165 infos.dnum().into(),
166 (infos.rank() + 1).into(),
167 (infos.rank() + 1).into(),
168 infos.size(),
169 );
170 lvl_0
171 }
172 fn ggsw_prepare<R, O>(&self, res: &mut R, other: &O, scratch: &mut ScratchArena<'_, B>)
173 where
174 R: GGSWPreparedToBackendMut<B>,
175 O: GGSWToBackendRef<B>,
176 {
177 let mut res = res.to_backend_mut();
178 let other = other.to_backend_ref();
179 assert_eq!(res.n(), self.ring_degree());
180 assert_eq!(other.n(), self.ring_degree());
181 assert_eq!(res.base2k, other.base2k);
182 assert_eq!(res.dsize, other.dsize);
183 assert!(
184 scratch.available() >= self.ggsw_prepare_tmp_bytes(&res),
185 "scratch.available(): {} < GGSWPreparedFactory::ggsw_prepare_tmp_bytes: {}",
186 scratch.available(),
187 self.ggsw_prepare_tmp_bytes(&res)
188 );
189 self.vmp_prepare(&mut res.data, &other.data, scratch);
190 }
191
192 fn ggsw_zero<R>(&self, res: &mut R)
193 where
194 R: GGSWPreparedToBackendMut<B>,
195 {
196 let mut res = res.to_backend_mut();
197 self.vmp_zero(&mut res.data);
198 }
199}
200
201impl<B: Backend> GGSWPreparedFactory<B> for Module<B> where
202 Self: GetDegree + VmpPMatAlloc<B> + VmpPMatBytesOf + VmpPrepareTmpBytes + VmpPrepare<B> + VmpZero<B>
203{
204}
205
206impl<D: HostDataRef, B: Backend> GGSWPrepared<D, B> {
209 pub fn data(&self) -> &VmpPMat<D, B> {
210 &self.data
211 }
212}
213
/// Conversion of a prepared GGSW into a read-only backend view ([`GGSWPreparedBackendRef`]).
pub trait GGSWPreparedToBackendRef<B: Backend> {
    /// Returns a [`GGSWPrepared`] whose storage is the backend's `BufRef` type,
    /// borrowed from `self` for the lifetime of the borrow.
    fn to_backend_ref(&self) -> GGSWPreparedBackendRef<'_, B>;
}
221
222impl<B: Backend> GGSWPreparedToBackendRef<B> for GGSWPrepared<B::OwnedBuf, B> {
223 fn to_backend_ref(&self) -> GGSWPreparedBackendRef<'_, B> {
224 GGSWPrepared {
225 base2k: self.base2k,
226 dsize: self.dsize,
227 data: self.data.to_backend_ref(),
228 }
229 }
230}
231
/// Conversion of a prepared GGSW into a mutable backend view ([`GGSWPreparedBackendMut`]).
pub trait GGSWPreparedToBackendMut<B: Backend> {
    /// Returns a [`GGSWPrepared`] whose storage is the backend's `BufMut` type,
    /// mutably borrowed from `self` for the lifetime of the borrow.
    fn to_backend_mut(&mut self) -> GGSWPreparedBackendMut<'_, B>;
}
235
236impl<B: Backend> GGSWPreparedToBackendMut<B> for GGSWPrepared<B::OwnedBuf, B> {
237 fn to_backend_mut(&mut self) -> GGSWPreparedBackendMut<'_, B> {
238 GGSWPrepared {
239 base2k: self.base2k,
240 dsize: self.dsize,
241 data: self.data.to_backend_mut(),
242 }
243 }
244}