poulpy_core/layouts/prepared/glwe_pk.rs

1use poulpy_hal::{
2    api::{VecZnxDftAlloc, VecZnxDftAllocBytes, VecZnxDftApply},
3    layouts::{Backend, Data, DataMut, DataRef, Module, Scratch, VecZnxDft},
4};
5
6use crate::{
7    dist::Distribution,
8    layouts::{
9        GLWEPublicKey, Infos,
10        prepared::{Prepare, PrepareAlloc},
11    },
12};
13
/// A [`GLWEPublicKey`] whose columns have been transformed into the backend's
/// DFT representation (stored as a [`VecZnxDft`]).
///
/// Built either with `alloc` followed by [`Prepare::prepare`], or in one step
/// with [`PrepareAlloc::prepare_alloc`] on a [`GLWEPublicKey`].
#[derive(PartialEq, Eq)]
pub struct GLWEPublicKeyPrepared<D: Data, B: Backend> {
    /// DFT-domain data. When created by `alloc` it has `rank + 1` columns of
    /// `k.div_ceil(basek)` limbs each.
    pub(crate) data: VecZnxDft<D, B>,
    /// Limb width parameter; the limb count is derived as `k.div_ceil(basek)`
    /// (presumably the log2 of the decomposition base — confirm).
    pub(crate) basek: usize,
    /// Precision parameter; together with `basek` it fixes the limb count.
    pub(crate) k: usize,
    /// Distribution tag copied from the source key by `prepare`;
    /// `Distribution::NONE` right after `alloc`.
    pub(crate) dist: Distribution,
}
21
22impl<D: Data, B: Backend> Infos for GLWEPublicKeyPrepared<D, B> {
23    type Inner = VecZnxDft<D, B>;
24
25    fn inner(&self) -> &Self::Inner {
26        &self.data
27    }
28
29    fn basek(&self) -> usize {
30        self.basek
31    }
32
33    fn k(&self) -> usize {
34        self.k
35    }
36}
37
38impl<D: Data, B: Backend> GLWEPublicKeyPrepared<D, B> {
39    pub fn rank(&self) -> usize {
40        self.cols() - 1
41    }
42}
43
44impl<B: Backend> GLWEPublicKeyPrepared<Vec<u8>, B> {
45    pub fn alloc(module: &Module<B>, basek: usize, k: usize, rank: usize) -> Self
46    where
47        Module<B>: VecZnxDftAlloc<B>,
48    {
49        Self {
50            data: module.vec_znx_dft_alloc(rank + 1, k.div_ceil(basek)),
51            basek,
52            k,
53            dist: Distribution::NONE,
54        }
55    }
56
57    pub fn bytes_of(module: &Module<B>, basek: usize, k: usize, rank: usize) -> usize
58    where
59        Module<B>: VecZnxDftAllocBytes,
60    {
61        module.vec_znx_dft_alloc_bytes(rank + 1, k.div_ceil(basek))
62    }
63}
64
65impl<D: DataRef, B: Backend> PrepareAlloc<B, GLWEPublicKeyPrepared<Vec<u8>, B>> for GLWEPublicKey<D>
66where
67    Module<B>: VecZnxDftAlloc<B> + VecZnxDftApply<B>,
68{
69    fn prepare_alloc(&self, module: &Module<B>, scratch: &mut Scratch<B>) -> GLWEPublicKeyPrepared<Vec<u8>, B> {
70        let mut pk_prepared: GLWEPublicKeyPrepared<Vec<u8>, B> =
71            GLWEPublicKeyPrepared::alloc(module, self.basek(), self.k(), self.rank());
72        pk_prepared.prepare(module, self, scratch);
73        pk_prepared
74    }
75}
76
77impl<DM: DataMut, DR: DataRef, B: Backend> Prepare<B, GLWEPublicKey<DR>> for GLWEPublicKeyPrepared<DM, B>
78where
79    Module<B>: VecZnxDftApply<B>,
80{
81    fn prepare(&mut self, module: &Module<B>, other: &GLWEPublicKey<DR>, _scratch: &mut Scratch<B>) {
82        #[cfg(debug_assertions)]
83        {
84            assert_eq!(self.n(), other.n());
85            assert_eq!(self.size(), other.size());
86        }
87
88        (0..self.cols()).for_each(|i| {
89            module.vec_znx_dft_apply(1, 0, &mut self.data, i, &other.data, i);
90        });
91        self.k = other.k;
92        self.basek = other.basek;
93        self.dist = other.dist;
94    }
95}