poulpy_core/layouts/prepared/
ggsw_ct.rs1use poulpy_hal::{
2 api::{VmpPMatAlloc, VmpPMatAllocBytes, VmpPrepare},
3 layouts::{Backend, Data, DataMut, DataRef, Module, Scratch, VmpPMat},
4};
5
6use crate::layouts::{
7 GGSWCiphertext, Infos,
8 prepared::{Prepare, PrepareAlloc},
9};
10
11#[derive(PartialEq, Eq)]
12pub struct GGSWCiphertextPrepared<D: Data, B: Backend> {
13 pub(crate) data: VmpPMat<D, B>,
14 pub(crate) basek: usize,
15 pub(crate) k: usize,
16 pub(crate) digits: usize,
17}
18
19impl<B: Backend> GGSWCiphertextPrepared<Vec<u8>, B> {
20 pub fn alloc(module: &Module<B>, n: usize, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> Self
21 where
22 Module<B>: VmpPMatAlloc<B>,
23 {
24 let size: usize = k.div_ceil(basek);
25 debug_assert!(digits > 0, "invalid ggsw: `digits` == 0");
26
27 debug_assert!(
28 size > digits,
29 "invalid ggsw: ceil(k/basek): {} <= digits: {}",
30 size,
31 digits
32 );
33
34 assert!(
35 rows * digits <= size,
36 "invalid ggsw: rows: {} * digits:{} > ceil(k/basek): {}",
37 rows,
38 digits,
39 size
40 );
41
42 Self {
43 data: module.vmp_pmat_alloc(n, rows, rank + 1, rank + 1, k.div_ceil(basek)),
44 basek,
45 k,
46 digits,
47 }
48 }
49
50 pub fn bytes_of(module: &Module<B>, n: usize, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> usize
51 where
52 Module<B>: VmpPMatAllocBytes,
53 {
54 let size: usize = k.div_ceil(basek);
55 debug_assert!(
56 size > digits,
57 "invalid ggsw: ceil(k/basek): {} <= digits: {}",
58 size,
59 digits
60 );
61
62 assert!(
63 rows * digits <= size,
64 "invalid ggsw: rows: {} * digits:{} > ceil(k/basek): {}",
65 rows,
66 digits,
67 size
68 );
69
70 module.vmp_pmat_alloc_bytes(n, rows, rank + 1, rank + 1, size)
71 }
72}
73
74impl<D: Data, B: Backend> Infos for GGSWCiphertextPrepared<D, B> {
75 type Inner = VmpPMat<D, B>;
76
77 fn inner(&self) -> &Self::Inner {
78 &self.data
79 }
80
81 fn basek(&self) -> usize {
82 self.basek
83 }
84
85 fn k(&self) -> usize {
86 self.k
87 }
88}
89
90impl<D: Data, B: Backend> GGSWCiphertextPrepared<D, B> {
91 pub fn rank(&self) -> usize {
92 self.data.cols_out() - 1
93 }
94
95 pub fn digits(&self) -> usize {
96 self.digits
97 }
98}
99
100impl<D: DataRef, B: Backend> GGSWCiphertextPrepared<D, B> {
101 pub fn data(&self) -> &VmpPMat<D, B> {
102 &self.data
103 }
104}
105
106impl<D: DataMut, DR: DataRef, B: Backend> Prepare<B, GGSWCiphertext<DR>> for GGSWCiphertextPrepared<D, B>
107where
108 Module<B>: VmpPrepare<B>,
109{
110 fn prepare(&mut self, module: &Module<B>, other: &GGSWCiphertext<DR>, scratch: &mut Scratch<B>) {
111 module.vmp_prepare(&mut self.data, &other.data, scratch);
112 self.k = other.k;
113 self.basek = other.basek;
114 self.digits = other.digits;
115 }
116}
117
118impl<D: DataRef, B: Backend> PrepareAlloc<B, GGSWCiphertextPrepared<Vec<u8>, B>> for GGSWCiphertext<D>
119where
120 Module<B>: VmpPMatAlloc<B> + VmpPrepare<B>,
121{
122 fn prepare_alloc(&self, module: &Module<B>, scratch: &mut Scratch<B>) -> GGSWCiphertextPrepared<Vec<u8>, B> {
123 let mut ggsw_prepared: GGSWCiphertextPrepared<Vec<u8>, B> = GGSWCiphertextPrepared::alloc(
124 module,
125 self.n(),
126 self.basek(),
127 self.k(),
128 self.rows(),
129 self.digits(),
130 self.rank(),
131 );
132 ggsw_prepared.prepare(module, self, scratch);
133 ggsw_prepared
134 }
135}