use poulpy_hal::{
    api::{VmpPMatAlloc, VmpPMatBytesOf, VmpPrepare, VmpPrepareTmpBytes},
    layouts::{Backend, Data, DataMut, DataRef, Module, Scratch, VmpPMat, VmpPMatToMut, VmpPMatToRef, ZnxInfos},
};

use crate::layouts::{
    Base2K, Degree, Dnum, Dsize, GGSW, GGSWInfos, GGSWToRef, GLWEInfos, GetDegree, LWEInfos, Rank, TorusPrecision,
};

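/// A GGSW ciphertext in "prepared" form: its gadget matrix has been
/// pre-processed into a [`VmpPMat`] so it can be consumed directly by the
/// backend's vector-matrix-product routines (e.g. during an external product),
/// while the layout metadata (`k`, `base2k`, `dsize`) needed to interpret it
/// is kept alongside.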
#[derive(PartialEq, Eq)]
pub struct GGSWPrepared<D: Data, B: Backend> {
    pub(crate) data: VmpPMat<D, B>,
    pub(crate) k: TorusPrecision,
    pub(crate) base2k: Base2K,
    pub(crate) dsize: Dsize,
}

impl<D: Data, B: Backend> LWEInfos for GGSWPrepared<D, B> {
    fn n(&self) -> Degree {
        Degree(self.data.n() as u32)
    }

    fn base2k(&self) -> Base2K {
        self.base2k
    }

    fn k(&self) -> TorusPrecision {
        self.k
    }

    fn size(&self) -> usize {
        self.data.size()
    }
}

impl<D: Data, B: Backend> GLWEInfos for GGSWPrepared<D, B> {
    fn rank(&self) -> Rank {
        Rank(self.data.cols_out() as u32 - 1)
    }
}

impl<D: Data, B: Backend> GGSWInfos for GGSWPrepared<D, B> {
    fn dsize(&self) -> Dsize {
        self.dsize
    }

    fn dnum(&self) -> Dnum {
        Dnum(self.data.rows() as u32)
    }
}

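/// Factory methods for allocating and preparing [`GGSWPrepared`] values on a
/// given module/backend. Blanket-implemented for any module that provides the
/// required `VmpPMat` primitives (see the impl for `Module<B>` below).
///
/// A minimal usage sketch, assuming a concrete backend `B`, a matching
/// `Module<B>`, an existing `GGSW` ciphertext `ggsw` with matching parameters,
/// and a `Scratch<B>` buffer of at least `ggsw_prepare_tmp_bytes` bytes (none
/// of which are constructed here):
///
/// ```ignore
/// let mut prepared = GGSWPrepared::alloc_from_infos(&module, &ggsw);
/// let tmp_bytes = module.ggsw_prepare_tmp_bytes(&ggsw);
/// // ... obtain `scratch: &mut Scratch<B>` with at least `tmp_bytes` bytes ...
/// prepared.prepare(&module, &ggsw, scratch);
/// ```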
pub trait GGSWPreparedFactory<B: Backend>
where
    Self: GetDegree + VmpPMatAlloc<B> + VmpPMatBytesOf + VmpPrepareTmpBytes + VmpPrepare<B>,
{
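    /// Allocates a fresh [`GGSWPrepared`] backed by owned memory for the given
    /// parameters, after validating that the decomposition fits:
    /// `dsize < ceil(k / base2k)` and `dnum * dsize <= ceil(k / base2k)`.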
    fn alloc_ggsw_prepared(
        &self,
        base2k: Base2K,
        k: TorusPrecision,
        dnum: Dnum,
        dsize: Dsize,
        rank: Rank,
    ) -> GGSWPrepared<Vec<u8>, B> {
        let size: usize = k.0.div_ceil(base2k.0) as usize;
        debug_assert!(
            size as u32 > dsize.0,
            "invalid ggsw: ceil(k/base2k): {size} <= dsize: {}",
            dsize.0
        );

        assert!(
            dnum.0 * dsize.0 <= size as u32,
            "invalid ggsw: dnum: {} * dsize: {} > ceil(k/base2k): {size}",
            dnum.0,
            dsize.0,
        );

        GGSWPrepared {
            data: self.vmp_pmat_alloc(dnum.into(), (rank + 1).into(), (rank + 1).into(), size),
            k,
            base2k,
            dsize,
        }
    }

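    /// Convenience wrapper around [`Self::alloc_ggsw_prepared`] that reads the
    /// parameters from any [`GGSWInfos`], asserting that its ring degree
    /// matches this module's.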
    fn alloc_ggsw_prepared_from_infos<A>(&self, infos: &A) -> GGSWPrepared<Vec<u8>, B>
    where
        A: GGSWInfos,
    {
        assert_eq!(self.ring_degree(), infos.n());
        self.alloc_ggsw_prepared(
            infos.base2k(),
            infos.k(),
            infos.dnum(),
            infos.dsize(),
            infos.rank(),
        )
    }

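    /// Returns the number of bytes needed to back a [`GGSWPrepared`] with the
    /// given parameters, applying the same validation as
    /// [`Self::alloc_ggsw_prepared`].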
    fn bytes_of_ggsw_prepared(&self, base2k: Base2K, k: TorusPrecision, dnum: Dnum, dsize: Dsize, rank: Rank) -> usize {
        let size: usize = k.0.div_ceil(base2k.0) as usize;
        debug_assert!(
            size as u32 > dsize.0,
            "invalid ggsw: ceil(k/base2k): {size} <= dsize: {}",
            dsize.0
        );

        assert!(
            dnum.0 * dsize.0 <= size as u32,
            "invalid ggsw: dnum: {} * dsize: {} > ceil(k/base2k): {size}",
            dnum.0,
            dsize.0,
        );

        self.bytes_of_vmp_pmat(dnum.into(), (rank + 1).into(), (rank + 1).into(), size)
    }

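    /// Convenience wrapper around [`Self::bytes_of_ggsw_prepared`] that reads
    /// the parameters from any [`GGSWInfos`], asserting that its ring degree
    /// matches this module's.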
    fn bytes_of_ggsw_prepared_from_infos<A>(&self, infos: &A) -> usize
    where
        A: GGSWInfos,
    {
        assert_eq!(self.ring_degree(), infos.n());
        self.bytes_of_ggsw_prepared(
            infos.base2k(),
            infos.k(),
            infos.dnum(),
            infos.dsize(),
            infos.rank(),
        )
    }

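    /// Returns the scratch-space size, in bytes, required by
    /// [`Self::ggsw_prepare`] for a ciphertext with the given layout.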
    fn ggsw_prepare_tmp_bytes<A>(&self, infos: &A) -> usize
    where
        A: GGSWInfos,
    {
        assert_eq!(self.ring_degree(), infos.n());
        self.vmp_prepare_tmp_bytes(
            infos.dnum().into(),
            (infos.rank() + 1).into(),
            (infos.rank() + 1).into(),
            infos.size(),
        )
    }

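    /// Prepares `other` into `res`, i.e. runs the backend's `vmp_prepare` on
    /// the underlying gadget matrix. Both operands must live in this module's
    /// ring and agree on `k`, `base2k`, and `dsize`; `scratch` must provide at
    /// least [`Self::ggsw_prepare_tmp_bytes`] bytes.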
    fn ggsw_prepare<R, O>(&self, res: &mut R, other: &O, scratch: &mut Scratch<B>)
    where
        R: GGSWPreparedToMut<B>,
        O: GGSWToRef,
    {
        let mut res: GGSWPrepared<&mut [u8], B> = res.to_mut();
        let other: GGSW<&[u8]> = other.to_ref();
        assert_eq!(res.n(), self.ring_degree());
        assert_eq!(other.n(), self.ring_degree());
        assert_eq!(res.k, other.k);
        assert_eq!(res.base2k, other.base2k);
        assert_eq!(res.dsize, other.dsize);
        self.vmp_prepare(&mut res.data, &other.data, scratch);
    }
}

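// Blanket implementation: every `Module<B>` that exposes the required
// `VmpPMat` primitives is automatically a `GGSWPreparedFactory<B>`.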
impl<B: Backend> GGSWPreparedFactory<B> for Module<B> where
    Self: GetDegree + VmpPMatAlloc<B> + VmpPMatBytesOf + VmpPrepareTmpBytes + VmpPrepare<B>
{
}

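/// Owned-allocation constructors; thin wrappers over the corresponding
/// [`GGSWPreparedFactory`] methods.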
impl<B: Backend> GGSWPrepared<Vec<u8>, B> {
    pub fn alloc_from_infos<A, M>(module: &M, infos: &A) -> Self
    where
        A: GGSWInfos,
        M: GGSWPreparedFactory<B>,
    {
        module.alloc_ggsw_prepared_from_infos(infos)
    }

    pub fn alloc<M>(module: &M, base2k: Base2K, k: TorusPrecision, dnum: Dnum, dsize: Dsize, rank: Rank) -> Self
    where
        M: GGSWPreparedFactory<B>,
    {
        module.alloc_ggsw_prepared(base2k, k, dnum, dsize, rank)
    }

    pub fn bytes_of_from_infos<A, M>(module: &M, infos: &A) -> usize
    where
        A: GGSWInfos,
        M: GGSWPreparedFactory<B>,
    {
        module.bytes_of_ggsw_prepared_from_infos(infos)
    }

    pub fn bytes_of<M>(module: &M, base2k: Base2K, k: TorusPrecision, dnum: Dnum, dsize: Dsize, rank: Rank) -> usize
    where
        M: GGSWPreparedFactory<B>,
    {
        module.bytes_of_ggsw_prepared(base2k, k, dnum, dsize, rank)
    }
}

impl<D: DataRef, B: Backend> GGSWPrepared<D, B> {
    /// Read-only access to the underlying prepared matrix.
    pub fn data(&self) -> &VmpPMat<D, B> {
        &self.data
    }
}

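/// Scratch sizing helper; delegates to
/// [`GGSWPreparedFactory::ggsw_prepare_tmp_bytes`]. Note that it does not read
/// `self`: the size is derived entirely from `infos`.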
impl<B: Backend> GGSWPrepared<Vec<u8>, B> {
    pub fn prepare_tmp_bytes<A, M>(&self, module: &M, infos: &A) -> usize
    where
        A: GGSWInfos,
        M: GGSWPreparedFactory<B>,
    {
        module.ggsw_prepare_tmp_bytes(infos)
    }
}

impl<D: DataMut, B: Backend> GGSWPrepared<D, B> {
    pub fn prepare<O, M>(&mut self, module: &M, other: &O, scratch: &mut Scratch<B>)
    where
        O: GGSWToRef,
        M: GGSWPreparedFactory<B>,
    {
        module.ggsw_prepare(self, other, scratch);
    }
}

pub trait GGSWPreparedToMut<B: Backend> {
    fn to_mut(&mut self) -> GGSWPrepared<&mut [u8], B>;
}

impl<D: DataMut, B: Backend> GGSWPreparedToMut<B> for GGSWPrepared<D, B> {
    fn to_mut(&mut self) -> GGSWPrepared<&mut [u8], B> {
        GGSWPrepared {
            base2k: self.base2k,
            k: self.k,
            dsize: self.dsize,
            data: self.data.to_mut(),
        }
    }
}

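/// Reborrows a `GGSWPrepared` as a shared byte-slice view, the read-only
/// counterpart of [`GGSWPreparedToMut`].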
pub trait GGSWPreparedToRef<B: Backend> {
    fn to_ref(&self) -> GGSWPrepared<&[u8], B>;
}

impl<D: DataRef, B: Backend> GGSWPreparedToRef<B> for GGSWPrepared<D, B> {
    fn to_ref(&self) -> GGSWPrepared<&[u8], B> {
        GGSWPrepared {
            base2k: self.base2k,
            k: self.k,
            dsize: self.dsize,
            data: self.data.to_ref(),
        }
    }
}