1use poulpy_hal::{
2 api::ModuleN,
3 layouts::{Backend, DataMut, Module, Scratch, ZnxView, ZnxViewMut, ZnxZero},
4};
5
6use crate::{
7 GLWEKeyswitch, ScratchTakeCore,
8 layouts::{GGLWEInfos, GGLWEPreparedToRef, GLWE, GLWEInfos, GLWELayout, GLWEToRef, LWE, LWEInfos, LWEToMut, Rank},
9};
10
11pub trait LWESampleExtract
12where
13 Self: ModuleN,
14{
15 fn lwe_sample_extract<R, A>(&self, res: &mut R, a: &A)
16 where
17 R: LWEToMut,
18 A: GLWEToRef,
19 {
20 let res: &mut LWE<&mut [u8]> = &mut res.to_mut();
21 let a: &GLWE<&[u8]> = &a.to_ref();
22
23 assert!(res.n() <= a.n());
24 assert_eq!(a.n(), self.n() as u32);
25 assert!(res.base2k() == a.base2k());
26
27 let min_size: usize = res.size().min(a.size());
28 let n: usize = res.n().into();
29
30 res.data.zero();
31 (0..min_size).for_each(|i| {
32 let data_lwe: &mut [i64] = res.data.at_mut(0, i);
33 data_lwe[0] = a.data.at(0, i)[0];
34 data_lwe[1..].copy_from_slice(&a.data.at(1, i)[..n]);
35 });
36 }
37}
38
39impl<BE: Backend> LWESampleExtract for Module<BE> where Self: ModuleN {}
40impl<BE: Backend> LWEFromGLWE<BE> for Module<BE> where Self: GLWEKeyswitch<BE> + LWESampleExtract {}
41
42pub trait LWEFromGLWE<BE: Backend>
43where
44 Self: GLWEKeyswitch<BE> + LWESampleExtract,
45{
46 fn lwe_from_glwe_tmp_bytes<R, A, K>(&self, lwe_infos: &R, glwe_infos: &A, key_infos: &K) -> usize
47 where
48 R: LWEInfos,
49 A: GLWEInfos,
50 K: GGLWEInfos,
51 {
52 let res_infos: GLWELayout = GLWELayout {
53 n: self.n().into(),
54 base2k: lwe_infos.base2k(),
55 k: lwe_infos.k(),
56 rank: Rank(1),
57 };
58
59 GLWE::bytes_of(
60 self.n().into(),
61 lwe_infos.base2k(),
62 lwe_infos.k(),
63 1u32.into(),
64 ) + self.glwe_keyswitch_tmp_bytes(&res_infos, glwe_infos, key_infos)
65 }
66
67 fn lwe_from_glwe<R, A, K>(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch<BE>)
68 where
69 R: LWEToMut,
70 A: GLWEToRef,
71 K: GGLWEPreparedToRef<BE> + GGLWEInfos,
72 Scratch<BE>: ScratchTakeCore<BE>,
73 {
74 let res: &mut LWE<&mut [u8]> = &mut res.to_mut();
75 let a: &GLWE<&[u8]> = &a.to_ref();
76
77 assert_eq!(a.n(), self.n() as u32);
78 assert_eq!(key.n(), self.n() as u32);
79 assert!(res.n() <= self.n() as u32);
80
81 let glwe_layout: GLWELayout = GLWELayout {
82 n: self.n().into(),
83 base2k: res.base2k(),
84 k: res.k(),
85 rank: Rank(1),
86 };
87
88 let (mut tmp_glwe, scratch_1) = scratch.take_glwe(&glwe_layout);
89 self.glwe_keyswitch(&mut tmp_glwe, a, key, scratch_1);
90 self.lwe_sample_extract(res, &tmp_glwe);
91 }
92}
93
94impl LWE<Vec<u8>> {
95 pub fn from_glwe_tmp_bytes<R, A, K, M, BE: Backend>(module: &M, lwe_infos: &R, glwe_infos: &A, key_infos: &K) -> usize
96 where
97 R: LWEInfos,
98 A: GLWEInfos,
99 K: GGLWEInfos,
100 M: LWEFromGLWE<BE>,
101 {
102 module.lwe_from_glwe_tmp_bytes(lwe_infos, glwe_infos, key_infos)
103 }
104}
105
106impl<D: DataMut> LWE<D> {
107 pub fn sample_extract<A, M>(&mut self, module: &M, a: &A)
108 where
109 A: GLWEToRef,
110 M: LWESampleExtract,
111 {
112 module.lwe_sample_extract(self, a);
113 }
114
115 pub fn from_glwe<A, K, M, BE: Backend>(&mut self, module: &M, a: &A, key: &K, scratch: &mut Scratch<BE>)
116 where
117 A: GLWEToRef,
118 K: GGLWEPreparedToRef<BE> + GGLWEInfos,
119 M: LWEFromGLWE<BE>,
120 Scratch<BE>: ScratchTakeCore<BE>,
121 {
122 module.lwe_from_glwe(self, a, key, scratch);
123 }
124}