// poulpy_core/external_product/ggsw.rs

1use poulpy_hal::{
2    api::{ModuleN, ScratchAvailable},
3    layouts::{Backend, DataMut, Module, Scratch, ZnxZero},
4};
5
6use crate::{
7    GLWEExternalProduct, ScratchTakeCore,
8    layouts::{
9        GGSW, GGSWInfos, GGSWToMut, GGSWToRef, GLWEInfos, LWEInfos,
10        prepared::{GGSWPrepared, GGSWPreparedToRef},
11    },
12};
13
14pub trait GGSWExternalProduct<BE: Backend>
15where
16    Self: GLWEExternalProduct<BE> + ModuleN,
17{
18    fn ggsw_external_product_tmp_bytes<R, A, B>(&self, res_infos: &R, a_infos: &A, b_infos: &B) -> usize
19    where
20        R: GGSWInfos,
21        A: GGSWInfos,
22        B: GGSWInfos,
23    {
24        self.glwe_external_product_tmp_bytes(res_infos, a_infos, b_infos)
25    }
26
27    fn ggsw_external_product<R, A, B>(&self, res: &mut R, a: &A, b: &B, scratch: &mut Scratch<BE>)
28    where
29        R: GGSWToMut,
30        A: GGSWToRef,
31        B: GGSWPreparedToRef<BE>,
32        Scratch<BE>: ScratchTakeCore<BE>,
33    {
34        let res: &mut GGSW<&mut [u8]> = &mut res.to_mut();
35        let a: &GGSW<&[u8]> = &a.to_ref();
36        let b: &GGSWPrepared<&[u8], BE> = &b.to_ref();
37
38        assert_eq!(
39            res.rank(),
40            a.rank(),
41            "res rank: {} != a rank: {}",
42            res.rank(),
43            a.rank()
44        );
45        assert_eq!(
46            res.rank(),
47            b.rank(),
48            "res rank: {} != b rank: {}",
49            res.rank(),
50            b.rank()
51        );
52
53        assert_eq!(res.base2k(), a.base2k());
54
55        assert!(scratch.available() >= self.ggsw_external_product_tmp_bytes(res, a, b));
56
57        let min_dnum: usize = res.dnum().min(a.dnum()).into();
58
59        for row in 0..min_dnum {
60            for col in 0..(res.rank() + 1).into() {
61                self.glwe_external_product(&mut res.at_mut(row, col), &a.at(row, col), b, scratch);
62            }
63        }
64
65        for row in min_dnum..res.dnum().into() {
66            for col in 0..(res.rank() + 1).into() {
67                res.at_mut(row, col).data.zero();
68            }
69        }
70    }
71
72    fn ggsw_external_product_inplace<R, A>(&self, res: &mut R, a: &A, scratch: &mut Scratch<BE>)
73    where
74        R: GGSWToMut,
75        A: GGSWPreparedToRef<BE>,
76        Scratch<BE>: ScratchTakeCore<BE>,
77    {
78        let res: &mut GGSW<&mut [u8]> = &mut res.to_mut();
79        let a: &GGSWPrepared<&[u8], BE> = &a.to_ref();
80
81        assert_eq!(res.n(), self.n() as u32);
82        assert_eq!(a.n(), self.n() as u32);
83        assert_eq!(
84            res.rank(),
85            a.rank(),
86            "res rank: {} != a rank: {}",
87            res.rank(),
88            a.rank()
89        );
90
91        for row in 0..res.dnum().into() {
92            for col in 0..(res.rank() + 1).into() {
93                self.glwe_external_product_inplace(&mut res.at_mut(row, col), a, scratch);
94            }
95        }
96    }
97}
98
99impl<BE: Backend> GGSWExternalProduct<BE> for Module<BE> where Self: GLWEExternalProduct<BE> {}
100
101impl GGSW<Vec<u8>> {
102    pub fn external_product_tmp_bytes<R, A, B, M, BE: Backend>(module: &M, res_infos: &R, a_infos: &A, b_infos: &B) -> usize
103    where
104        R: GGSWInfos,
105        A: GGSWInfos,
106        B: GGSWInfos,
107        M: GGSWExternalProduct<BE>,
108    {
109        module.ggsw_external_product_tmp_bytes(res_infos, a_infos, b_infos)
110    }
111}
112
113impl<DataSelf: DataMut> GGSW<DataSelf> {
114    pub fn external_product<A, B, M, BE: Backend>(&mut self, module: &M, a: &A, b: &B, scratch: &mut Scratch<BE>)
115    where
116        M: GGSWExternalProduct<BE>,
117        A: GGSWToRef,
118        B: GGSWPreparedToRef<BE>,
119        Scratch<BE>: ScratchTakeCore<BE>,
120    {
121        module.ggsw_external_product(self, a, b, scratch);
122    }
123
124    pub fn external_product_inplace<A, M, BE: Backend>(&mut self, module: &M, a: &A, scratch: &mut Scratch<BE>)
125    where
126        M: GGSWExternalProduct<BE>,
127        A: GGSWPreparedToRef<BE>,
128        Scratch<BE>: ScratchTakeCore<BE>,
129    {
130        module.ggsw_external_product_inplace(self, a, scratch);
131    }
132}