poulpy_core/layouts/prepared/
glwe_pk.rs1use poulpy_hal::{
2 api::{VecZnxDftAlloc, VecZnxDftAllocBytes, VecZnxDftApply},
3 layouts::{Backend, Data, DataMut, DataRef, Module, Scratch, VecZnxDft},
4};
5
6use crate::{
7 dist::Distribution,
8 layouts::{
9 GLWEPublicKey, Infos,
10 prepared::{Prepare, PrepareAlloc},
11 },
12};
13
14#[derive(PartialEq, Eq)]
15pub struct GLWEPublicKeyPrepared<D: Data, B: Backend> {
16 pub(crate) data: VecZnxDft<D, B>,
17 pub(crate) basek: usize,
18 pub(crate) k: usize,
19 pub(crate) dist: Distribution,
20}
21
22impl<D: Data, B: Backend> Infos for GLWEPublicKeyPrepared<D, B> {
23 type Inner = VecZnxDft<D, B>;
24
25 fn inner(&self) -> &Self::Inner {
26 &self.data
27 }
28
29 fn basek(&self) -> usize {
30 self.basek
31 }
32
33 fn k(&self) -> usize {
34 self.k
35 }
36}
37
38impl<D: Data, B: Backend> GLWEPublicKeyPrepared<D, B> {
39 pub fn rank(&self) -> usize {
40 self.cols() - 1
41 }
42}
43
44impl<B: Backend> GLWEPublicKeyPrepared<Vec<u8>, B> {
45 pub fn alloc(module: &Module<B>, basek: usize, k: usize, rank: usize) -> Self
46 where
47 Module<B>: VecZnxDftAlloc<B>,
48 {
49 Self {
50 data: module.vec_znx_dft_alloc(rank + 1, k.div_ceil(basek)),
51 basek,
52 k,
53 dist: Distribution::NONE,
54 }
55 }
56
57 pub fn bytes_of(module: &Module<B>, basek: usize, k: usize, rank: usize) -> usize
58 where
59 Module<B>: VecZnxDftAllocBytes,
60 {
61 module.vec_znx_dft_alloc_bytes(rank + 1, k.div_ceil(basek))
62 }
63}
64
65impl<D: DataRef, B: Backend> PrepareAlloc<B, GLWEPublicKeyPrepared<Vec<u8>, B>> for GLWEPublicKey<D>
66where
67 Module<B>: VecZnxDftAlloc<B> + VecZnxDftApply<B>,
68{
69 fn prepare_alloc(&self, module: &Module<B>, scratch: &mut Scratch<B>) -> GLWEPublicKeyPrepared<Vec<u8>, B> {
70 let mut pk_prepared: GLWEPublicKeyPrepared<Vec<u8>, B> =
71 GLWEPublicKeyPrepared::alloc(module, self.basek(), self.k(), self.rank());
72 pk_prepared.prepare(module, self, scratch);
73 pk_prepared
74 }
75}
76
77impl<DM: DataMut, DR: DataRef, B: Backend> Prepare<B, GLWEPublicKey<DR>> for GLWEPublicKeyPrepared<DM, B>
78where
79 Module<B>: VecZnxDftApply<B>,
80{
81 fn prepare(&mut self, module: &Module<B>, other: &GLWEPublicKey<DR>, _scratch: &mut Scratch<B>) {
82 #[cfg(debug_assertions)]
83 {
84 assert_eq!(self.n(), other.n());
85 assert_eq!(self.size(), other.size());
86 }
87
88 (0..self.cols()).for_each(|i| {
89 module.vec_znx_dft_apply(1, 0, &mut self.data, i, &other.data, i);
90 });
91 self.k = other.k;
92 self.basek = other.basek;
93 self.dist = other.dist;
94 }
95}