poulpy_core/layouts/prepared/
glwe_pk.rs1use poulpy_hal::{
2 api::{VecZnxDftAlloc, VecZnxDftAllocBytes, VecZnxDftApply},
3 layouts::{Backend, Data, DataMut, DataRef, Module, Scratch, VecZnxDft, ZnxInfos},
4 oep::VecZnxDftAllocBytesImpl,
5};
6
7use crate::{
8 dist::Distribution,
9 layouts::{
10 Base2K, BuildError, Degree, GLWEInfos, GLWEPublicKey, LWEInfos, Rank, TorusPrecision,
11 prepared::{Prepare, PrepareAlloc},
12 },
13};
14
15#[derive(PartialEq, Eq)]
16pub struct GLWEPublicKeyPrepared<D: Data, B: Backend> {
17 pub(crate) data: VecZnxDft<D, B>,
18 pub(crate) base2k: Base2K,
19 pub(crate) k: TorusPrecision,
20 pub(crate) dist: Distribution,
21}
22
23impl<D: Data, B: Backend> LWEInfos for GLWEPublicKeyPrepared<D, B> {
24 fn base2k(&self) -> Base2K {
25 self.base2k
26 }
27
28 fn k(&self) -> TorusPrecision {
29 self.k
30 }
31
32 fn size(&self) -> usize {
33 self.data.size()
34 }
35
36 fn n(&self) -> Degree {
37 Degree(self.data.n() as u32)
38 }
39}
40
41impl<D: Data, B: Backend> GLWEInfos for GLWEPublicKeyPrepared<D, B> {
42 fn rank(&self) -> Rank {
43 Rank(self.data.cols() as u32 - 1)
44 }
45}
46
47pub struct GLWEPublicKeyPreparedBuilder<D: Data, B: Backend> {
48 data: Option<VecZnxDft<D, B>>,
49 base2k: Option<Base2K>,
50 k: Option<TorusPrecision>,
51}
52
53impl<D: Data, B: Backend> GLWEPublicKeyPrepared<D, B> {
54 #[inline]
55 pub fn builder() -> GLWEPublicKeyPreparedBuilder<D, B> {
56 GLWEPublicKeyPreparedBuilder {
57 data: None,
58 base2k: None,
59 k: None,
60 }
61 }
62}
63
64impl<B: Backend> GLWEPublicKeyPreparedBuilder<Vec<u8>, B> {
65 #[inline]
66 pub fn layout<A>(mut self, layout: &A) -> Self
67 where
68 A: GLWEInfos,
69 B: VecZnxDftAllocBytesImpl<B>,
70 {
71 self.data = Some(VecZnxDft::alloc(
72 layout.n().into(),
73 (layout.rank() + 1).into(),
74 layout.size(),
75 ));
76 self.base2k = Some(layout.base2k());
77 self.k = Some(layout.k());
78 self
79 }
80}
81
82impl<D: Data, B: Backend> GLWEPublicKeyPreparedBuilder<D, B> {
83 #[inline]
84 pub fn data(mut self, data: VecZnxDft<D, B>) -> Self {
85 self.data = Some(data);
86 self
87 }
88 #[inline]
89 pub fn base2k(mut self, base2k: Base2K) -> Self {
90 self.base2k = Some(base2k);
91 self
92 }
93 #[inline]
94 pub fn k(mut self, k: TorusPrecision) -> Self {
95 self.k = Some(k);
96 self
97 }
98
99 pub fn build(self) -> Result<GLWEPublicKeyPrepared<D, B>, BuildError> {
100 let data: VecZnxDft<D, B> = self.data.ok_or(BuildError::MissingData)?;
101 let base2k: Base2K = self.base2k.ok_or(BuildError::MissingBase2K)?;
102 let k: TorusPrecision = self.k.ok_or(BuildError::MissingK)?;
103
104 if base2k == 0_u32 {
105 return Err(BuildError::ZeroBase2K);
106 }
107
108 if k == 0_u32 {
109 return Err(BuildError::ZeroTorusPrecision);
110 }
111
112 if data.n() == 0 {
113 return Err(BuildError::ZeroDegree);
114 }
115
116 if data.cols() == 0 {
117 return Err(BuildError::ZeroCols);
118 }
119
120 if data.size() == 0 {
121 return Err(BuildError::ZeroLimbs);
122 }
123
124 Ok(GLWEPublicKeyPrepared {
125 data,
126 base2k,
127 k,
128 dist: Distribution::NONE,
129 })
130 }
131}
132
133impl<B: Backend> GLWEPublicKeyPrepared<Vec<u8>, B> {
134 pub fn alloc<A>(module: &Module<B>, infos: &A) -> Self
135 where
136 A: GLWEInfos,
137 Module<B>: VecZnxDftAlloc<B>,
138 {
139 debug_assert_eq!(module.n(), infos.n().0 as usize, "module.n() != infos.n()");
140 Self::alloc_with(module, infos.base2k(), infos.k(), infos.rank())
141 }
142
143 pub fn alloc_with(module: &Module<B>, base2k: Base2K, k: TorusPrecision, rank: Rank) -> Self
144 where
145 Module<B>: VecZnxDftAlloc<B>,
146 {
147 Self {
148 data: module.vec_znx_dft_alloc((rank + 1).into(), k.0.div_ceil(base2k.0) as usize),
149 base2k,
150 k,
151 dist: Distribution::NONE,
152 }
153 }
154
155 pub fn alloc_bytes<A>(module: &Module<B>, infos: &A) -> usize
156 where
157 A: GLWEInfos,
158 Module<B>: VecZnxDftAllocBytes,
159 {
160 debug_assert_eq!(module.n(), infos.n().0 as usize, "module.n() != infos.n()");
161 Self::alloc_bytes_with(module, infos.base2k(), infos.k(), infos.rank())
162 }
163
164 pub fn alloc_bytes_with(module: &Module<B>, base2k: Base2K, k: TorusPrecision, rank: Rank) -> usize
165 where
166 Module<B>: VecZnxDftAllocBytes,
167 {
168 module.vec_znx_dft_alloc_bytes((rank + 1).into(), k.0.div_ceil(base2k.0) as usize)
169 }
170}
171
172impl<D: DataRef, B: Backend> PrepareAlloc<B, GLWEPublicKeyPrepared<Vec<u8>, B>> for GLWEPublicKey<D>
173where
174 Module<B>: VecZnxDftAlloc<B> + VecZnxDftApply<B>,
175{
176 fn prepare_alloc(&self, module: &Module<B>, scratch: &mut Scratch<B>) -> GLWEPublicKeyPrepared<Vec<u8>, B> {
177 let mut pk_prepared: GLWEPublicKeyPrepared<Vec<u8>, B> = GLWEPublicKeyPrepared::alloc(module, self);
178 pk_prepared.prepare(module, self, scratch);
179 pk_prepared
180 }
181}
182
183impl<DM: DataMut, DR: DataRef, B: Backend> Prepare<B, GLWEPublicKey<DR>> for GLWEPublicKeyPrepared<DM, B>
184where
185 Module<B>: VecZnxDftApply<B>,
186{
187 fn prepare(&mut self, module: &Module<B>, other: &GLWEPublicKey<DR>, _scratch: &mut Scratch<B>) {
188 #[cfg(debug_assertions)]
189 {
190 assert_eq!(self.n(), other.n());
191 assert_eq!(self.size(), other.size());
192 }
193
194 (0..(self.rank() + 1).into()).for_each(|i| {
195 module.vec_znx_dft_apply(1, 0, &mut self.data, i, &other.data, i);
196 });
197 self.k = other.k();
198 self.base2k = other.base2k();
199 self.dist = other.dist;
200 }
201}