poulpy_core/layouts/prepared/
gglwe.rs1use poulpy_hal::{
2 api::{VmpPMatAlloc, VmpPMatBytesOf, VmpPrepare, VmpPrepareTmpBytes},
3 layouts::{Backend, Data, Module, ScratchArena, VmpPMat, VmpPMatToBackendMut, VmpPMatToBackendRef},
4};
5
6use crate::layouts::{
7 Base2K, Degree, Dnum, Dsize, GGLWEInfos, GGLWEToBackendRef, GLWEInfos, GetDegree, LWEInfos, Rank, TorusPrecision,
8};
9
10#[derive(PartialEq, Eq)]
19pub struct GGLWEPrepared<D: Data, B: Backend> {
20 pub(crate) data: VmpPMat<D, B>,
21 pub(crate) base2k: Base2K,
22 pub(crate) dsize: Dsize,
23}
24
25pub type GGLWEPreparedBackendRef<'a, B> = GGLWEPrepared<<B as Backend>::BufRef<'a>, B>;
26pub type GGLWEPreparedBackendMut<'a, B> = GGLWEPrepared<<B as Backend>::BufMut<'a>, B>;
27
28impl<D: Data, B: Backend> LWEInfos for GGLWEPrepared<D, B> {
30 fn n(&self) -> Degree {
31 Degree(self.data.n() as u32)
32 }
33
34 fn base2k(&self) -> Base2K {
35 self.base2k
36 }
37
38 fn size(&self) -> usize {
39 self.data.size()
40 }
41}
42
43impl<D: Data, B: Backend> GLWEInfos for GGLWEPrepared<D, B> {
45 fn rank(&self) -> Rank {
46 self.rank_out()
47 }
48}
49
50impl<D: Data, B: Backend> GGLWEInfos for GGLWEPrepared<D, B> {
52 fn rank_in(&self) -> Rank {
53 Rank(self.data.cols_in() as u32)
54 }
55
56 fn rank_out(&self) -> Rank {
57 Rank(self.data.cols_out() as u32 - 1)
58 }
59
60 fn dsize(&self) -> Dsize {
61 self.dsize
62 }
63
64 fn dnum(&self) -> Dnum {
65 Dnum(self.data.rows() as u32)
66 }
67}
68
69pub trait GGLWEPreparedFactory<BE: Backend>
74where
75 Self: GetDegree + VmpPMatAlloc<BE> + VmpPMatBytesOf + VmpPrepare<BE> + VmpPrepareTmpBytes,
76{
77 fn gglwe_prepared_alloc(
81 &self,
82 base2k: Base2K,
83 k: TorusPrecision,
84 rank_in: Rank,
85 rank_out: Rank,
86 dnum: Dnum,
87 dsize: Dsize,
88 ) -> GGLWEPrepared<BE::OwnedBuf, BE> {
89 let size: usize = k.0.div_ceil(base2k.0) as usize;
90 debug_assert!(
91 size as u32 > dsize.0,
92 "invalid gglwe: ceil(k/base2k): {size} <= dsize: {}",
93 dsize.0
94 );
95
96 assert!(
97 dnum.0 * dsize.0 <= size as u32,
98 "invalid gglwe: dnum: {} * dsize:{} > ceil(k/base2k): {size}",
99 dnum.0,
100 dsize.0,
101 );
102
103 GGLWEPrepared {
104 data: self.vmp_pmat_alloc(dnum.into(), rank_in.into(), (rank_out + 1).into(), size),
105 base2k,
106 dsize,
107 }
108 }
109
110 fn gglwe_prepared_alloc_from_infos<A>(&self, infos: &A) -> GGLWEPrepared<BE::OwnedBuf, BE>
112 where
113 A: GGLWEInfos,
114 {
115 assert_eq!(self.ring_degree(), infos.n());
116 self.gglwe_prepared_alloc(
117 infos.base2k(),
118 infos.max_k(),
119 infos.rank_in(),
120 infos.rank_out(),
121 infos.dnum(),
122 infos.dsize(),
123 )
124 }
125
126 fn gglwe_prepared_bytes_of(
128 &self,
129 base2k: Base2K,
130 k: TorusPrecision,
131 rank_in: Rank,
132 rank_out: Rank,
133 dnum: Dnum,
134 dsize: Dsize,
135 ) -> usize {
136 let size: usize = k.0.div_ceil(base2k.0) as usize;
137 debug_assert!(
138 size as u32 > dsize.0,
139 "invalid gglwe: ceil(k/base2k): {size} <= dsize: {}",
140 dsize.0
141 );
142
143 assert!(
144 dnum.0 * dsize.0 <= size as u32,
145 "invalid gglwe: dnum: {} * dsize:{} > ceil(k/base2k): {size}",
146 dnum.0,
147 dsize.0,
148 );
149
150 self.bytes_of_vmp_pmat(dnum.into(), rank_in.into(), (rank_out + 1).into(), size)
151 }
152
153 fn gglwe_prepared_bytes_of_from_infos<A>(&self, infos: &A) -> usize
155 where
156 A: GGLWEInfos,
157 {
158 assert_eq!(self.ring_degree(), infos.n());
159 self.gglwe_prepared_bytes_of(
160 infos.base2k(),
161 infos.max_k(),
162 infos.rank_in(),
163 infos.rank_out(),
164 infos.dnum(),
165 infos.dsize(),
166 )
167 }
168
169 fn gglwe_prepare_tmp_bytes<A>(&self, infos: &A) -> usize
171 where
172 A: GGLWEInfos,
173 {
174 let lvl_0: usize = self.vmp_prepare_tmp_bytes(
175 infos.dnum().into(),
176 infos.rank_in().into(),
177 (infos.rank() + 1).into(),
178 infos.size(),
179 );
180 lvl_0
181 }
182
183 fn gglwe_prepare<R, O>(&self, res: &mut R, other: &O, scratch: &mut ScratchArena<'_, BE>)
187 where
188 R: GGLWEPreparedToBackendMut<BE>,
189 O: GGLWEToBackendRef<BE>,
190 {
191 let mut res = res.to_backend_mut();
192 let other = other.to_backend_ref();
193
194 assert_eq!(res.n(), self.ring_degree());
195 assert_eq!(other.n(), self.ring_degree());
196 assert_eq!(res.base2k, other.base2k);
197 assert_eq!(res.size(), other.size());
198 assert_eq!(res.dsize, other.dsize);
199 assert!(
200 scratch.available() >= self.gglwe_prepare_tmp_bytes(&res),
201 "scratch.available(): {} < GGLWEPreparedFactory::gglwe_prepare_tmp_bytes: {}",
202 scratch.available(),
203 self.gglwe_prepare_tmp_bytes(&res)
204 );
205 self.vmp_prepare(&mut res.data, &other.data, scratch);
206 }
207}
208
209impl<BE: Backend> GGLWEPreparedFactory<BE> for Module<BE> where
210 Module<BE>: GetDegree + VmpPMatAlloc<BE> + VmpPMatBytesOf + VmpPrepare<BE> + VmpPrepareTmpBytes
211{
212}
213
214pub trait GGLWEPreparedToBackendRef<B: Backend> {
219 fn to_backend_ref(&self) -> GGLWEPreparedBackendRef<'_, B>;
220}
221
222impl<B: Backend> GGLWEPreparedToBackendRef<B> for GGLWEPrepared<B::OwnedBuf, B> {
223 fn to_backend_ref(&self) -> GGLWEPreparedBackendRef<'_, B> {
224 GGLWEPrepared {
225 base2k: self.base2k,
226 dsize: self.dsize,
227 data: self.data.to_backend_ref(),
228 }
229 }
230}
231
232pub trait GGLWEPreparedToBackendMut<B: Backend> {
233 fn to_backend_mut(&mut self) -> GGLWEPreparedBackendMut<'_, B>;
234}
235
236impl<B: Backend> GGLWEPreparedToBackendMut<B> for GGLWEPrepared<B::OwnedBuf, B> {
237 fn to_backend_mut(&mut self) -> GGLWEPreparedBackendMut<'_, B> {
238 GGLWEPrepared {
239 base2k: self.base2k,
240 dsize: self.dsize,
241 data: self.data.to_backend_mut(),
242 }
243 }
244}