poulpy_core/conversion/
gglwe_to_ggsw.rs

1use poulpy_hal::{
2    api::{
3        ScratchAvailable, ScratchTakeBasic, VecZnxBigAddSmallInplace, VecZnxBigBytesOf, VecZnxBigNormalize,
4        VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftApply, VecZnxDftBytesOf, VecZnxIdftApplyConsume, VecZnxNormalize,
5    },
6    layouts::{Backend, DataMut, Module, Scratch, VecZnx, VecZnxBig, VecZnxDft, VecZnxDftToRef, VecZnxToRef},
7};
8
9use crate::{
10    GGLWEProduct, GLWECopy, ScratchTakeCore,
11    layouts::{
12        GGLWE, GGLWEInfos, GGLWEToGGSWKeyPrepared, GGLWEToGGSWKeyPreparedToRef, GGLWEToRef, GGSW, GGSWInfos, GGSWToMut, GLWE,
13        GLWEInfos, LWEInfos,
14    },
15};
16
17impl GGLWE<Vec<u8>> {
18    pub fn from_gglw_tmp_bytes<R, A, M, BE: Backend>(module: &M, res_infos: &R, tsk_infos: &A) -> usize
19    where
20        M: GGSWFromGGLWE<BE>,
21        R: GGSWInfos,
22        A: GGLWEInfos,
23    {
24        module.ggsw_from_gglwe_tmp_bytes(res_infos, tsk_infos)
25    }
26}
27
28impl<D: DataMut> GGSW<D> {
29    pub fn from_gglwe<G, M, T, BE: Backend>(&mut self, module: &M, gglwe: &G, tsk: &T, scratch: &mut Scratch<BE>)
30    where
31        M: GGSWFromGGLWE<BE>,
32        G: GGLWEToRef,
33        T: GGLWEToGGSWKeyPreparedToRef<BE>,
34        Scratch<BE>: ScratchTakeCore<BE>,
35    {
36        module.ggsw_from_gglwe(self, gglwe, tsk, scratch);
37    }
38}
39
40impl<BE: Backend> GGSWFromGGLWE<BE> for Module<BE>
41where
42    Self: GGSWExpandRows<BE> + GLWECopy,
43{
44    fn ggsw_from_gglwe_tmp_bytes<R, A>(&self, res_infos: &R, tsk_infos: &A) -> usize
45    where
46        R: GGSWInfos,
47        A: GGLWEInfos,
48    {
49        self.ggsw_expand_rows_tmp_bytes(res_infos, tsk_infos)
50    }
51
52    fn ggsw_from_gglwe<R, A, T>(&self, res: &mut R, a: &A, tsk: &T, scratch: &mut Scratch<BE>)
53    where
54        R: GGSWToMut,
55        A: GGLWEToRef,
56        T: GGLWEToGGSWKeyPreparedToRef<BE>,
57        Scratch<BE>: ScratchTakeCore<BE>,
58    {
59        let res: &mut GGSW<&mut [u8]> = &mut res.to_mut();
60        let a: &GGLWE<&[u8]> = &a.to_ref();
61        let tsk: &GGLWEToGGSWKeyPrepared<&[u8], BE> = &tsk.to_ref();
62
63        assert_eq!(res.rank(), a.rank_out());
64        assert_eq!(res.dnum(), a.dnum());
65        assert_eq!(res.n(), self.n() as u32);
66        assert_eq!(a.n(), self.n() as u32);
67        assert_eq!(tsk.n(), self.n() as u32);
68        assert_eq!(res.base2k(), a.base2k());
69
70        for row in 0..res.dnum().into() {
71            self.glwe_copy(&mut res.at_mut(row, 0), &a.at(row, 0));
72        }
73
74        self.ggsw_expand_row(res, tsk, scratch);
75    }
76}
77
78pub trait GGSWFromGGLWE<BE: Backend> {
79    fn ggsw_from_gglwe_tmp_bytes<R, A>(&self, res_infos: &R, tsk_infos: &A) -> usize
80    where
81        R: GGSWInfos,
82        A: GGLWEInfos;
83
84    fn ggsw_from_gglwe<R, A, T>(&self, res: &mut R, a: &A, tsk: &T, scratch: &mut Scratch<BE>)
85    where
86        R: GGSWToMut,
87        A: GGLWEToRef,
88        T: GGLWEToGGSWKeyPreparedToRef<BE>,
89        Scratch<BE>: ScratchTakeCore<BE>;
90}
91
92pub trait GGSWExpandRows<BE: Backend> {
93    fn ggsw_expand_rows_tmp_bytes<R, A>(&self, res_infos: &R, tsk_infos: &A) -> usize
94    where
95        R: GGSWInfos,
96        A: GGLWEInfos;
97
98    fn ggsw_expand_row<R, T>(&self, res: &mut R, tsk: &T, scratch: &mut Scratch<BE>)
99    where
100        R: GGSWToMut,
101        T: GGLWEToGGSWKeyPreparedToRef<BE>,
102        Scratch<BE>: ScratchTakeCore<BE>;
103}
104
105impl<BE: Backend> GGSWExpandRows<BE> for Module<BE>
106where
107    Self: GGLWEProduct<BE>
108        + VecZnxBigNormalize<BE>
109        + VecZnxBigNormalizeTmpBytes
110        + VecZnxBigBytesOf
111        + VecZnxDftBytesOf
112        + VecZnxDftApply<BE>
113        + VecZnxNormalize<BE>
114        + VecZnxBigAddSmallInplace<BE>
115        + VecZnxIdftApplyConsume<BE>
116        + VecZnxCopy,
117{
118    fn ggsw_expand_rows_tmp_bytes<R, A>(&self, res_infos: &R, tsk_infos: &A) -> usize
119    where
120        R: GGSWInfos,
121        A: GGLWEInfos,
122    {
123        let base2k_tsk: usize = tsk_infos.base2k().into();
124
125        let rank: usize = res_infos.rank().into();
126        let cols: usize = rank + 1;
127
128        let res_size: usize = res_infos.size();
129        let a_size: usize = res_infos.max_k().as_usize().div_ceil(base2k_tsk);
130
131        let a_0: usize = VecZnx::bytes_of(self.n(), 1, a_size);
132        let a_dft: usize = self.bytes_of_vec_znx_dft(cols - 1, a_size);
133        let res_dft: usize = self.bytes_of_vec_znx_dft(cols, a_size);
134        let gglwe_prod: usize = self.gglwe_product_dft_tmp_bytes(res_size, a_size, tsk_infos);
135        let normalize: usize = self.vec_znx_big_normalize_tmp_bytes();
136
137        (a_0 + a_dft + res_dft + gglwe_prod).max(normalize)
138    }
139
140    fn ggsw_expand_row<R, T>(&self, res: &mut R, tsk: &T, scratch: &mut Scratch<BE>)
141    where
142        R: GGSWToMut,
143        T: GGLWEToGGSWKeyPreparedToRef<BE>,
144        Scratch<BE>: ScratchTakeCore<BE>,
145    {
146        let res: &mut GGSW<&mut [u8]> = &mut res.to_mut();
147        let tsk: &GGLWEToGGSWKeyPrepared<&[u8], BE> = &tsk.to_ref();
148
149        let base2k_res: usize = res.base2k().into();
150        let base2k_tsk: usize = tsk.base2k().into();
151
152        assert!(scratch.available() >= self.ggsw_expand_rows_tmp_bytes(res, tsk));
153
154        let rank: usize = res.rank().into();
155        let cols: usize = rank + 1;
156
157        let res_conv_size: usize = res.max_k().as_usize().div_ceil(base2k_tsk);
158
159        let (mut a_dft, scratch_1) = scratch.take_vec_znx_dft(self, cols - 1, res_conv_size);
160        let (mut a_0, scratch_2) = scratch_1.take_vec_znx(self.n(), 1, res_conv_size);
161
162        // Keyswitch the j-th row of the col 0
163        for row in 0..res.dnum().as_usize() {
164            let glwe_mi_1: &GLWE<&[u8]> = &res.at(row, 0);
165
166            if base2k_res == base2k_tsk {
167                for col_i in 0..cols - 1 {
168                    self.vec_znx_dft_apply(1, 0, &mut a_dft, col_i, glwe_mi_1.data(), col_i + 1);
169                }
170                self.vec_znx_copy(&mut a_0, 0, glwe_mi_1.data(), 0);
171            } else {
172                for i in 0..cols - 1 {
173                    self.vec_znx_normalize(
174                        base2k_tsk,
175                        &mut a_0,
176                        0,
177                        base2k_res,
178                        glwe_mi_1.data(),
179                        i + 1,
180                        scratch_2,
181                    );
182                    self.vec_znx_dft_apply(1, 0, &mut a_dft, i, &a_0, 0);
183                }
184                self.vec_znx_normalize(
185                    base2k_tsk,
186                    &mut a_0,
187                    0,
188                    base2k_res,
189                    glwe_mi_1.data(),
190                    0,
191                    scratch_2,
192                );
193            }
194
195            ggsw_expand_rows_internal(self, row, res, &a_0, &a_dft, tsk, scratch_2)
196        }
197    }
198}
199
200fn ggsw_expand_rows_internal<M, R, C, A, T, BE: Backend>(
201    module: &M,
202    row: usize,
203    res: &mut R,
204    a_0: &C,
205    a_dft: &A,
206    tsk: &T,
207    scratch: &mut Scratch<BE>,
208) where
209    R: GGSWToMut,
210    C: VecZnxToRef,
211    A: VecZnxDftToRef<BE>,
212    M: GGLWEProduct<BE> + VecZnxIdftApplyConsume<BE> + VecZnxBigAddSmallInplace<BE> + VecZnxBigNormalize<BE>,
213    T: GGLWEToGGSWKeyPreparedToRef<BE>,
214    Scratch<BE>: ScratchTakeCore<BE>,
215{
216    let res: &mut GGSW<&mut [u8]> = &mut res.to_mut();
217    let a_0: &VecZnx<&[u8]> = &a_0.to_ref();
218    let a_dft: &VecZnxDft<&[u8], BE> = &a_dft.to_ref();
219    let tsk: &GGLWEToGGSWKeyPrepared<&[u8], BE> = &tsk.to_ref();
220    let cols: usize = res.rank().as_usize() + 1;
221
222    // Example for rank 3:
223    //
224    // Note: M is a vector (m, Bm, B^2m, B^3m, ...), so each column is
225    // actually composed of that many dnum and we focus on a specific row here
226    // implicitely given ci_dft.
227    //
228    // # Input
229    //
230    // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0    , a1    , a2    )
231    // col 1: (0, 0, 0, 0)
232    // col 2: (0, 0, 0, 0)
233    // col 3: (0, 0, 0, 0)
234    //
235    // # Output
236    //
237    // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0       , a1       , a2       )
238    // col 1: (-(b0s0 + b1s1 + b2s2)       , b0 + M[i], b1       , b2       )
239    // col 2: (-(c0s0 + c1s1 + c2s2)       , c0       , c1 + M[i], c2       )
240    // col 3: (-(d0s0 + d1s1 + d2s2)       , d0       , d1       , d2 + M[i])
241    for col in 1..cols {
242        let (mut res_dft, scratch_1) = scratch.take_vec_znx_dft(module, cols, tsk.size()); // Todo optimise
243
244        // Performs a key-switch for each combination of s[i]*s[j], i.e. for a0, a1, a2
245        //
246        // # Example for col=1
247        //
248        // a0 * (-(f0s0 + f1s1 + f1s2) + s0^2, f0, f1, f2) = (-(a0f0s0 + a0f1s1 + a0f1s2) + a0s0^2, a0f0, a0f1, a0f2)
249        // +
250        // a1 * (-(g0s0 + g1s1 + g1s2) + s0s1, g0, g1, g2) = (-(a1g0s0 + a1g1s1 + a1g1s2) + a1s0s1, a1g0, a1g1, a1g2)
251        // +
252        // a2 * (-(h0s0 + h1s1 + h1s2) + s0s2, h0, h1, h2) = (-(a2h0s0 + a2h1s1 + a2h1s2) + a2s0s2, a2h0, a2h1, a2h2)
253        // =
254        // (-(x0s0 + x1s1 + x2s2) + s0(a0s0 + a1s1 + a2s2), x0, x1, x2)
255        module.gglwe_product_dft(&mut res_dft, a_dft, tsk.at(col - 1), scratch_1);
256
257        let mut res_big: VecZnxBig<&mut [u8], BE> = module.vec_znx_idft_apply_consume(res_dft);
258
259        // Adds -(sum a[i] * s[i]) + m)  on the i-th column of tmp_idft_i
260        //
261        // (-(x0s0 + x1s1 + x2s2) + a0s0s0 + a1s0s1 + a2s0s2, x0, x1, x2)
262        // +
263        // (0, -(a0s0 + a1s1 + a2s2) + M[i], 0, 0)
264        // =
265        // (-(x0s0 + x1s1 + x2s2) + s0(a0s0 + a1s1 + a2s2), x0 -(a0s0 + a1s1 + a2s2) + M[i], x1, x2)
266        // =
267        // (-(x0s0 + x1s1 + x2s2), x0 + M[i], x1, x2)
268        module.vec_znx_big_add_small_inplace(&mut res_big, col, a_0, 0);
269
270        for j in 0..cols {
271            module.vec_znx_big_normalize(
272                res.base2k().as_usize(),
273                res.at_mut(row, col).data_mut(),
274                j,
275                tsk.base2k().as_usize(),
276                &res_big,
277                j,
278                scratch_1,
279            );
280        }
281    }
282}