poulpy_core/layouts/prepared/glwe_pk.rs

1use poulpy_hal::{
2    api::{VecZnxDftAlloc, VecZnxDftAllocBytes, VecZnxDftApply},
3    layouts::{Backend, Data, DataMut, DataRef, Module, Scratch, VecZnxDft, ZnxInfos},
4    oep::VecZnxDftAllocBytesImpl,
5};
6
7use crate::{
8    dist::Distribution,
9    layouts::{
10        Base2K, BuildError, Degree, GLWEInfos, GLWEPublicKey, LWEInfos, Rank, TorusPrecision,
11        prepared::{Prepare, PrepareAlloc},
12    },
13};
14
15#[derive(PartialEq, Eq)]
16pub struct GLWEPublicKeyPrepared<D: Data, B: Backend> {
17    pub(crate) data: VecZnxDft<D, B>,
18    pub(crate) base2k: Base2K,
19    pub(crate) k: TorusPrecision,
20    pub(crate) dist: Distribution,
21}
22
23impl<D: Data, B: Backend> LWEInfos for GLWEPublicKeyPrepared<D, B> {
24    fn base2k(&self) -> Base2K {
25        self.base2k
26    }
27
28    fn k(&self) -> TorusPrecision {
29        self.k
30    }
31
32    fn size(&self) -> usize {
33        self.data.size()
34    }
35
36    fn n(&self) -> Degree {
37        Degree(self.data.n() as u32)
38    }
39}
40
41impl<D: Data, B: Backend> GLWEInfos for GLWEPublicKeyPrepared<D, B> {
42    fn rank(&self) -> Rank {
43        Rank(self.data.cols() as u32 - 1)
44    }
45}
46
47pub struct GLWEPublicKeyPreparedBuilder<D: Data, B: Backend> {
48    data: Option<VecZnxDft<D, B>>,
49    base2k: Option<Base2K>,
50    k: Option<TorusPrecision>,
51}
52
53impl<D: Data, B: Backend> GLWEPublicKeyPrepared<D, B> {
54    #[inline]
55    pub fn builder() -> GLWEPublicKeyPreparedBuilder<D, B> {
56        GLWEPublicKeyPreparedBuilder {
57            data: None,
58            base2k: None,
59            k: None,
60        }
61    }
62}
63
64impl<B: Backend> GLWEPublicKeyPreparedBuilder<Vec<u8>, B> {
65    #[inline]
66    pub fn layout<A>(mut self, layout: &A) -> Self
67    where
68        A: GLWEInfos,
69        B: VecZnxDftAllocBytesImpl<B>,
70    {
71        self.data = Some(VecZnxDft::alloc(
72            layout.n().into(),
73            (layout.rank() + 1).into(),
74            layout.size(),
75        ));
76        self.base2k = Some(layout.base2k());
77        self.k = Some(layout.k());
78        self
79    }
80}
81
82impl<D: Data, B: Backend> GLWEPublicKeyPreparedBuilder<D, B> {
83    #[inline]
84    pub fn data(mut self, data: VecZnxDft<D, B>) -> Self {
85        self.data = Some(data);
86        self
87    }
88    #[inline]
89    pub fn base2k(mut self, base2k: Base2K) -> Self {
90        self.base2k = Some(base2k);
91        self
92    }
93    #[inline]
94    pub fn k(mut self, k: TorusPrecision) -> Self {
95        self.k = Some(k);
96        self
97    }
98
99    pub fn build(self) -> Result<GLWEPublicKeyPrepared<D, B>, BuildError> {
100        let data: VecZnxDft<D, B> = self.data.ok_or(BuildError::MissingData)?;
101        let base2k: Base2K = self.base2k.ok_or(BuildError::MissingBase2K)?;
102        let k: TorusPrecision = self.k.ok_or(BuildError::MissingK)?;
103
104        if base2k == 0_u32 {
105            return Err(BuildError::ZeroBase2K);
106        }
107
108        if k == 0_u32 {
109            return Err(BuildError::ZeroTorusPrecision);
110        }
111
112        if data.n() == 0 {
113            return Err(BuildError::ZeroDegree);
114        }
115
116        if data.cols() == 0 {
117            return Err(BuildError::ZeroCols);
118        }
119
120        if data.size() == 0 {
121            return Err(BuildError::ZeroLimbs);
122        }
123
124        Ok(GLWEPublicKeyPrepared {
125            data,
126            base2k,
127            k,
128            dist: Distribution::NONE,
129        })
130    }
131}
132
133impl<B: Backend> GLWEPublicKeyPrepared<Vec<u8>, B> {
134    pub fn alloc<A>(module: &Module<B>, infos: &A) -> Self
135    where
136        A: GLWEInfos,
137        Module<B>: VecZnxDftAlloc<B>,
138    {
139        debug_assert_eq!(module.n(), infos.n().0 as usize, "module.n() != infos.n()");
140        Self::alloc_with(module, infos.base2k(), infos.k(), infos.rank())
141    }
142
143    pub fn alloc_with(module: &Module<B>, base2k: Base2K, k: TorusPrecision, rank: Rank) -> Self
144    where
145        Module<B>: VecZnxDftAlloc<B>,
146    {
147        Self {
148            data: module.vec_znx_dft_alloc((rank + 1).into(), k.0.div_ceil(base2k.0) as usize),
149            base2k,
150            k,
151            dist: Distribution::NONE,
152        }
153    }
154
155    pub fn alloc_bytes<A>(module: &Module<B>, infos: &A) -> usize
156    where
157        A: GLWEInfos,
158        Module<B>: VecZnxDftAllocBytes,
159    {
160        debug_assert_eq!(module.n(), infos.n().0 as usize, "module.n() != infos.n()");
161        Self::alloc_bytes_with(module, infos.base2k(), infos.k(), infos.rank())
162    }
163
164    pub fn alloc_bytes_with(module: &Module<B>, base2k: Base2K, k: TorusPrecision, rank: Rank) -> usize
165    where
166        Module<B>: VecZnxDftAllocBytes,
167    {
168        module.vec_znx_dft_alloc_bytes((rank + 1).into(), k.0.div_ceil(base2k.0) as usize)
169    }
170}
171
172impl<D: DataRef, B: Backend> PrepareAlloc<B, GLWEPublicKeyPrepared<Vec<u8>, B>> for GLWEPublicKey<D>
173where
174    Module<B>: VecZnxDftAlloc<B> + VecZnxDftApply<B>,
175{
176    fn prepare_alloc(&self, module: &Module<B>, scratch: &mut Scratch<B>) -> GLWEPublicKeyPrepared<Vec<u8>, B> {
177        let mut pk_prepared: GLWEPublicKeyPrepared<Vec<u8>, B> = GLWEPublicKeyPrepared::alloc(module, self);
178        pk_prepared.prepare(module, self, scratch);
179        pk_prepared
180    }
181}
182
183impl<DM: DataMut, DR: DataRef, B: Backend> Prepare<B, GLWEPublicKey<DR>> for GLWEPublicKeyPrepared<DM, B>
184where
185    Module<B>: VecZnxDftApply<B>,
186{
187    fn prepare(&mut self, module: &Module<B>, other: &GLWEPublicKey<DR>, _scratch: &mut Scratch<B>) {
188        #[cfg(debug_assertions)]
189        {
190            assert_eq!(self.n(), other.n());
191            assert_eq!(self.size(), other.size());
192        }
193
194        (0..(self.rank() + 1).into()).for_each(|i| {
195            module.vec_znx_dft_apply(1, 0, &mut self.data, i, &other.data, i);
196        });
197        self.k = other.k();
198        self.base2k = other.base2k();
199        self.dist = other.dist;
200    }
201}