poulpy_core/layouts/prepared/
ggsw_ct.rs

1use poulpy_hal::{
2    api::{VmpPMatAlloc, VmpPMatAllocBytes, VmpPrepare},
3    layouts::{Backend, Data, DataMut, DataRef, Module, Scratch, VmpPMat},
4};
5
6use crate::layouts::{
7    GGSWCiphertext, Infos,
8    prepared::{Prepare, PrepareAlloc},
9};
10
11#[derive(PartialEq, Eq)]
12pub struct GGSWCiphertextPrepared<D: Data, B: Backend> {
13    pub(crate) data: VmpPMat<D, B>,
14    pub(crate) basek: usize,
15    pub(crate) k: usize,
16    pub(crate) digits: usize,
17}
18
19impl<B: Backend> GGSWCiphertextPrepared<Vec<u8>, B> {
20    pub fn alloc(module: &Module<B>, n: usize, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> Self
21    where
22        Module<B>: VmpPMatAlloc<B>,
23    {
24        let size: usize = k.div_ceil(basek);
25        debug_assert!(digits > 0, "invalid ggsw: `digits` == 0");
26
27        debug_assert!(
28            size > digits,
29            "invalid ggsw: ceil(k/basek): {} <= digits: {}",
30            size,
31            digits
32        );
33
34        assert!(
35            rows * digits <= size,
36            "invalid ggsw: rows: {} * digits:{} > ceil(k/basek): {}",
37            rows,
38            digits,
39            size
40        );
41
42        Self {
43            data: module.vmp_pmat_alloc(n, rows, rank + 1, rank + 1, k.div_ceil(basek)),
44            basek,
45            k,
46            digits,
47        }
48    }
49
50    pub fn bytes_of(module: &Module<B>, n: usize, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> usize
51    where
52        Module<B>: VmpPMatAllocBytes,
53    {
54        let size: usize = k.div_ceil(basek);
55        debug_assert!(
56            size > digits,
57            "invalid ggsw: ceil(k/basek): {} <= digits: {}",
58            size,
59            digits
60        );
61
62        assert!(
63            rows * digits <= size,
64            "invalid ggsw: rows: {} * digits:{} > ceil(k/basek): {}",
65            rows,
66            digits,
67            size
68        );
69
70        module.vmp_pmat_alloc_bytes(n, rows, rank + 1, rank + 1, size)
71    }
72}
73
74impl<D: Data, B: Backend> Infos for GGSWCiphertextPrepared<D, B> {
75    type Inner = VmpPMat<D, B>;
76
77    fn inner(&self) -> &Self::Inner {
78        &self.data
79    }
80
81    fn basek(&self) -> usize {
82        self.basek
83    }
84
85    fn k(&self) -> usize {
86        self.k
87    }
88}
89
90impl<D: Data, B: Backend> GGSWCiphertextPrepared<D, B> {
91    pub fn rank(&self) -> usize {
92        self.data.cols_out() - 1
93    }
94
95    pub fn digits(&self) -> usize {
96        self.digits
97    }
98}
99
100impl<D: DataRef, B: Backend> GGSWCiphertextPrepared<D, B> {
101    pub fn data(&self) -> &VmpPMat<D, B> {
102        &self.data
103    }
104}
105
106impl<D: DataMut, DR: DataRef, B: Backend> Prepare<B, GGSWCiphertext<DR>> for GGSWCiphertextPrepared<D, B>
107where
108    Module<B>: VmpPrepare<B>,
109{
110    fn prepare(&mut self, module: &Module<B>, other: &GGSWCiphertext<DR>, scratch: &mut Scratch<B>) {
111        module.vmp_prepare(&mut self.data, &other.data, scratch);
112        self.k = other.k;
113        self.basek = other.basek;
114        self.digits = other.digits;
115    }
116}
117
118impl<D: DataRef, B: Backend> PrepareAlloc<B, GGSWCiphertextPrepared<Vec<u8>, B>> for GGSWCiphertext<D>
119where
120    Module<B>: VmpPMatAlloc<B> + VmpPrepare<B>,
121{
122    fn prepare_alloc(&self, module: &Module<B>, scratch: &mut Scratch<B>) -> GGSWCiphertextPrepared<Vec<u8>, B> {
123        let mut ggsw_prepared: GGSWCiphertextPrepared<Vec<u8>, B> = GGSWCiphertextPrepared::alloc(
124            module,
125            self.n(),
126            self.basek(),
127            self.k(),
128            self.rows(),
129            self.digits(),
130            self.rank(),
131        );
132        ggsw_prepared.prepare(module, self, scratch);
133        ggsw_prepared
134    }
135}