poulpy_core/keyswitching/ggsw.rs

1use poulpy_hal::layouts::{Backend, DataMut, Module, Scratch, VecZnx};
2
3use crate::{
4    GGSWExpandRows, ScratchTakeCore,
5    keyswitching::GLWEKeyswitch,
6    layouts::{GGLWEInfos, GGLWEPreparedToRef, GGSW, GGSWInfos, GGSWToMut, GGSWToRef, prepared::GLWETensorKeyPreparedToRef},
7};
8
9impl GGSW<Vec<u8>> {
10    pub fn keyswitch_tmp_bytes<R, A, K, T, M, BE: Backend>(
11        module: &M,
12        res_infos: &R,
13        a_infos: &A,
14        key_infos: &K,
15        tsk_infos: &T,
16    ) -> usize
17    where
18        R: GGSWInfos,
19        A: GGSWInfos,
20        K: GGLWEInfos,
21        T: GGLWEInfos,
22        M: GGSWKeyswitch<BE>,
23    {
24        module.ggsw_keyswitch_tmp_bytes(res_infos, a_infos, key_infos, tsk_infos)
25    }
26}
27
28impl<D: DataMut> GGSW<D> {
29    pub fn keyswitch<M, A, K, T, BE: Backend>(&mut self, module: &M, a: &A, key: &K, tsk: &T, scratch: &mut Scratch<BE>)
30    where
31        A: GGSWToRef,
32        K: GGLWEPreparedToRef<BE>,
33        T: GLWETensorKeyPreparedToRef<BE>,
34        Scratch<BE>: ScratchTakeCore<BE>,
35        M: GGSWKeyswitch<BE>,
36    {
37        module.ggsw_keyswitch(self, a, key, tsk, scratch);
38    }
39
40    pub fn keyswitch_inplace<M, K, T, BE: Backend>(&mut self, module: &M, key: &K, tsk: &T, scratch: &mut Scratch<BE>)
41    where
42        K: GGLWEPreparedToRef<BE>,
43        T: GLWETensorKeyPreparedToRef<BE>,
44        Scratch<BE>: ScratchTakeCore<BE>,
45        M: GGSWKeyswitch<BE>,
46    {
47        module.ggsw_keyswitch_inplace(self, key, tsk, scratch);
48    }
49}
50
51impl<BE: Backend> GGSWKeyswitch<BE> for Module<BE> where Self: GLWEKeyswitch<BE> + GGSWExpandRows<BE> {}
52
53pub trait GGSWKeyswitch<BE: Backend>
54where
55    Self: GLWEKeyswitch<BE> + GGSWExpandRows<BE>,
56{
57    fn ggsw_keyswitch_tmp_bytes<R, A, K, T>(&self, res_infos: &R, a_infos: &A, key_infos: &K, tsk_infos: &T) -> usize
58    where
59        R: GGSWInfos,
60        A: GGSWInfos,
61        K: GGLWEInfos,
62        T: GGLWEInfos,
63    {
64        assert_eq!(key_infos.rank_in(), key_infos.rank_out());
65        assert_eq!(tsk_infos.rank_in(), tsk_infos.rank_out());
66        assert_eq!(key_infos.rank_in(), tsk_infos.rank_in());
67
68        let rank: usize = key_infos.rank_out().into();
69
70        let size_out: usize = res_infos.k().div_ceil(res_infos.base2k()) as usize;
71        let res_znx: usize = VecZnx::bytes_of(self.n(), rank + 1, size_out);
72        let ci_dft: usize = self.bytes_of_vec_znx_dft(rank + 1, size_out);
73        let ks: usize = self.glwe_keyswitch_tmp_bytes(res_infos, a_infos, key_infos);
74        let expand_rows: usize = self.ggsw_expand_rows_tmp_bytes(res_infos, tsk_infos);
75        let res_dft: usize = self.bytes_of_vec_znx_dft(rank + 1, size_out);
76
77        if a_infos.base2k() == tsk_infos.base2k() {
78            res_znx + ci_dft + (ks | expand_rows | res_dft)
79        } else {
80            let a_conv: usize = VecZnx::bytes_of(
81                self.n(),
82                1,
83                res_infos.k().div_ceil(tsk_infos.base2k()) as usize,
84            ) + self.vec_znx_normalize_tmp_bytes();
85            res_znx + ci_dft + (a_conv | ks | expand_rows | res_dft)
86        }
87    }
88
89    fn ggsw_keyswitch<R, A, K, T>(&self, res: &mut R, a: &A, key: &K, tsk: &T, scratch: &mut Scratch<BE>)
90    where
91        R: GGSWToMut,
92        A: GGSWToRef,
93        K: GGLWEPreparedToRef<BE>,
94        T: GLWETensorKeyPreparedToRef<BE>,
95        Scratch<BE>: ScratchTakeCore<BE>,
96    {
97        let res: &mut GGSW<&mut [u8]> = &mut res.to_mut();
98        let a: &GGSW<&[u8]> = &a.to_ref();
99
100        assert!(res.dnum() <= a.dnum());
101        assert_eq!(res.dsize(), a.dsize());
102
103        for row in 0..a.dnum().into() {
104            // Key-switch column 0, i.e.
105            // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0s0' + a1s1' + a2s2') + M[i], a0, a1, a2)
106            self.glwe_keyswitch(&mut res.at_mut(row, 0), &a.at(row, 0), key, scratch);
107        }
108
109        self.ggsw_expand_row(res, tsk, scratch);
110    }
111
112    fn ggsw_keyswitch_inplace<R, K, T>(&self, res: &mut R, key: &K, tsk: &T, scratch: &mut Scratch<BE>)
113    where
114        R: GGSWToMut,
115        K: GGLWEPreparedToRef<BE>,
116        T: GLWETensorKeyPreparedToRef<BE>,
117        Scratch<BE>: ScratchTakeCore<BE>,
118    {
119        let res: &mut GGSW<&mut [u8]> = &mut res.to_mut();
120
121        for row in 0..res.dnum().into() {
122            // Key-switch column 0, i.e.
123            // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0s0' + a1s1' + a2s2') + M[i], a0, a1, a2)
124            self.glwe_keyswitch_inplace(&mut res.at_mut(row, 0), key, scratch);
125        }
126
127        self.ggsw_expand_row(res, tsk, scratch);
128    }
129}