poulpy_core/external_product/
ggsw.rs1use poulpy_hal::{
2 api::{ModuleN, ScratchAvailable},
3 layouts::{Backend, DataMut, Module, Scratch, ZnxZero},
4};
5
6use crate::{
7 GLWEExternalProduct, ScratchTakeCore,
8 layouts::{
9 GGSW, GGSWInfos, GGSWToMut, GGSWToRef, GLWEInfos, LWEInfos,
10 prepared::{GGSWPrepared, GGSWPreparedToRef},
11 },
12};
13
14pub trait GGSWExternalProduct<BE: Backend>
15where
16 Self: GLWEExternalProduct<BE> + ModuleN,
17{
18 fn ggsw_external_product_tmp_bytes<R, A, B>(&self, res_infos: &R, a_infos: &A, b_infos: &B) -> usize
19 where
20 R: GGSWInfos,
21 A: GGSWInfos,
22 B: GGSWInfos,
23 {
24 self.glwe_external_product_tmp_bytes(res_infos, a_infos, b_infos)
25 }
26
27 fn ggsw_external_product<R, A, B>(&self, res: &mut R, a: &A, b: &B, scratch: &mut Scratch<BE>)
28 where
29 R: GGSWToMut,
30 A: GGSWToRef,
31 B: GGSWPreparedToRef<BE>,
32 Scratch<BE>: ScratchTakeCore<BE>,
33 {
34 let res: &mut GGSW<&mut [u8]> = &mut res.to_mut();
35 let a: &GGSW<&[u8]> = &a.to_ref();
36 let b: &GGSWPrepared<&[u8], BE> = &b.to_ref();
37
38 assert_eq!(
39 res.rank(),
40 a.rank(),
41 "res rank: {} != a rank: {}",
42 res.rank(),
43 a.rank()
44 );
45 assert_eq!(
46 res.rank(),
47 b.rank(),
48 "res rank: {} != b rank: {}",
49 res.rank(),
50 b.rank()
51 );
52
53 assert_eq!(res.base2k(), a.base2k());
54
55 assert!(scratch.available() >= self.ggsw_external_product_tmp_bytes(res, a, b));
56
57 let min_dnum: usize = res.dnum().min(a.dnum()).into();
58
59 for row in 0..min_dnum {
60 for col in 0..(res.rank() + 1).into() {
61 self.glwe_external_product(&mut res.at_mut(row, col), &a.at(row, col), b, scratch);
62 }
63 }
64
65 for row in min_dnum..res.dnum().into() {
66 for col in 0..(res.rank() + 1).into() {
67 res.at_mut(row, col).data.zero();
68 }
69 }
70 }
71
72 fn ggsw_external_product_inplace<R, A>(&self, res: &mut R, a: &A, scratch: &mut Scratch<BE>)
73 where
74 R: GGSWToMut,
75 A: GGSWPreparedToRef<BE>,
76 Scratch<BE>: ScratchTakeCore<BE>,
77 {
78 let res: &mut GGSW<&mut [u8]> = &mut res.to_mut();
79 let a: &GGSWPrepared<&[u8], BE> = &a.to_ref();
80
81 assert_eq!(res.n(), self.n() as u32);
82 assert_eq!(a.n(), self.n() as u32);
83 assert_eq!(
84 res.rank(),
85 a.rank(),
86 "res rank: {} != a rank: {}",
87 res.rank(),
88 a.rank()
89 );
90
91 for row in 0..res.dnum().into() {
92 for col in 0..(res.rank() + 1).into() {
93 self.glwe_external_product_inplace(&mut res.at_mut(row, col), a, scratch);
94 }
95 }
96 }
97}
98
99impl<BE: Backend> GGSWExternalProduct<BE> for Module<BE> where Self: GLWEExternalProduct<BE> {}
100
101impl GGSW<Vec<u8>> {
102 pub fn external_product_tmp_bytes<R, A, B, M, BE: Backend>(module: &M, res_infos: &R, a_infos: &A, b_infos: &B) -> usize
103 where
104 R: GGSWInfos,
105 A: GGSWInfos,
106 B: GGSWInfos,
107 M: GGSWExternalProduct<BE>,
108 {
109 module.ggsw_external_product_tmp_bytes(res_infos, a_infos, b_infos)
110 }
111}
112
113impl<DataSelf: DataMut> GGSW<DataSelf> {
114 pub fn external_product<A, B, M, BE: Backend>(&mut self, module: &M, a: &A, b: &B, scratch: &mut Scratch<BE>)
115 where
116 M: GGSWExternalProduct<BE>,
117 A: GGSWToRef,
118 B: GGSWPreparedToRef<BE>,
119 Scratch<BE>: ScratchTakeCore<BE>,
120 {
121 module.ggsw_external_product(self, a, b, scratch);
122 }
123
124 pub fn external_product_inplace<A, M, BE: Backend>(&mut self, module: &M, a: &A, scratch: &mut Scratch<BE>)
125 where
126 M: GGSWExternalProduct<BE>,
127 A: GGSWPreparedToRef<BE>,
128 Scratch<BE>: ScratchTakeCore<BE>,
129 {
130 module.ggsw_external_product_inplace(self, a, scratch);
131 }
132}