// poulpy_core/layouts/compressed/gglwe_tsk.rs

1use poulpy_hal::{
2    api::{VecZnxCopy, VecZnxFillUniform},
3    layouts::{Backend, Data, DataMut, DataRef, FillUniform, Module, ReaderFrom, WriterTo},
4    source::Source,
5};
6
7use crate::layouts::{
8    Base2K, Degree, Digits, GGLWELayoutInfos, GGLWETensorKey, GLWEInfos, LWEInfos, Rank, Rows, TorusPrecision,
9    compressed::{Decompress, GGLWESwitchingKeyCompressed},
10};
11use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
12use std::fmt;
13
14#[derive(PartialEq, Eq, Clone)]
15pub struct GGLWETensorKeyCompressed<D: Data> {
16    pub(crate) keys: Vec<GGLWESwitchingKeyCompressed<D>>,
17}
18
19impl<D: Data> LWEInfos for GGLWETensorKeyCompressed<D> {
20    fn n(&self) -> Degree {
21        self.keys[0].n()
22    }
23
24    fn base2k(&self) -> Base2K {
25        self.keys[0].base2k()
26    }
27
28    fn k(&self) -> TorusPrecision {
29        self.keys[0].k()
30    }
31    fn size(&self) -> usize {
32        self.keys[0].size()
33    }
34}
35impl<D: Data> GLWEInfos for GGLWETensorKeyCompressed<D> {
36    fn rank(&self) -> Rank {
37        self.rank_out()
38    }
39}
40
41impl<D: Data> GGLWELayoutInfos for GGLWETensorKeyCompressed<D> {
42    fn rank_in(&self) -> Rank {
43        self.rank_out()
44    }
45
46    fn rank_out(&self) -> Rank {
47        self.keys[0].rank_out()
48    }
49
50    fn digits(&self) -> Digits {
51        self.keys[0].digits()
52    }
53
54    fn rows(&self) -> Rows {
55        self.keys[0].rows()
56    }
57}
58
59impl<D: DataRef> fmt::Debug for GGLWETensorKeyCompressed<D> {
60    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
61        write!(f, "{self}")
62    }
63}
64
65impl<D: DataMut> FillUniform for GGLWETensorKeyCompressed<D> {
66    fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) {
67        self.keys
68            .iter_mut()
69            .for_each(|key: &mut GGLWESwitchingKeyCompressed<D>| key.fill_uniform(log_bound, source))
70    }
71}
72
73impl<D: DataRef> fmt::Display for GGLWETensorKeyCompressed<D> {
74    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
75        writeln!(f, "(GLWETensorKeyCompressed)",)?;
76        for (i, key) in self.keys.iter().enumerate() {
77            write!(f, "{i}: {key}")?;
78        }
79        Ok(())
80    }
81}
82
83impl GGLWETensorKeyCompressed<Vec<u8>> {
84    pub fn alloc<A>(infos: &A) -> Self
85    where
86        A: GGLWELayoutInfos,
87    {
88        assert_eq!(
89            infos.rank_in(),
90            infos.rank_out(),
91            "rank_in != rank_out is not supported for GGLWETensorKeyCompressed"
92        );
93        Self::alloc_with(
94            infos.n(),
95            infos.base2k(),
96            infos.k(),
97            infos.rows(),
98            infos.digits(),
99            infos.rank_out(),
100        )
101    }
102
103    pub fn alloc_with(n: Degree, base2k: Base2K, k: TorusPrecision, rows: Rows, digits: Digits, rank: Rank) -> Self {
104        let mut keys: Vec<GGLWESwitchingKeyCompressed<Vec<u8>>> = Vec::new();
105        let pairs: u32 = (((rank.0 + 1) * rank.0) >> 1).max(1);
106        (0..pairs).for_each(|_| {
107            keys.push(GGLWESwitchingKeyCompressed::alloc_with(
108                n,
109                base2k,
110                k,
111                rows,
112                digits,
113                Rank(1),
114                rank,
115            ));
116        });
117        Self { keys }
118    }
119
120    pub fn alloc_bytes<A>(infos: &A) -> usize
121    where
122        A: GGLWELayoutInfos,
123    {
124        assert_eq!(
125            infos.rank_in(),
126            infos.rank_out(),
127            "rank_in != rank_out is not supported for GGLWETensorKeyCompressed"
128        );
129        let rank_out: usize = infos.rank_out().into();
130        let pairs: usize = (((rank_out + 1) * rank_out) >> 1).max(1);
131        pairs
132            * GGLWESwitchingKeyCompressed::alloc_bytes_with(
133                infos.n(),
134                infos.base2k(),
135                infos.k(),
136                infos.rows(),
137                infos.digits(),
138                Rank(1),
139                infos.rank_out(),
140            )
141    }
142
143    pub fn alloc_bytes_with(n: Degree, base2k: Base2K, k: TorusPrecision, rows: Rows, digits: Digits, rank: Rank) -> usize {
144        let pairs: usize = (((rank.0 + 1) * rank.0) >> 1).max(1) as usize;
145        pairs * GGLWESwitchingKeyCompressed::alloc_bytes_with(n, base2k, k, rows, digits, Rank(1), rank)
146    }
147}
148
149impl<D: DataMut> ReaderFrom for GGLWETensorKeyCompressed<D> {
150    fn read_from<R: std::io::Read>(&mut self, reader: &mut R) -> std::io::Result<()> {
151        let len: usize = reader.read_u64::<LittleEndian>()? as usize;
152        if self.keys.len() != len {
153            return Err(std::io::Error::new(
154                std::io::ErrorKind::InvalidData,
155                format!("self.keys.len()={} != read len={}", self.keys.len(), len),
156            ));
157        }
158        for key in &mut self.keys {
159            key.read_from(reader)?;
160        }
161        Ok(())
162    }
163}
164
165impl<D: DataRef> WriterTo for GGLWETensorKeyCompressed<D> {
166    fn write_to<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<()> {
167        writer.write_u64::<LittleEndian>(self.keys.len() as u64)?;
168        for key in &self.keys {
169            key.write_to(writer)?;
170        }
171        Ok(())
172    }
173}
174
175impl<D: DataMut> GGLWETensorKeyCompressed<D> {
176    pub(crate) fn at_mut(&mut self, mut i: usize, mut j: usize) -> &mut GGLWESwitchingKeyCompressed<D> {
177        if i > j {
178            std::mem::swap(&mut i, &mut j);
179        };
180        let rank: usize = self.rank_out().into();
181        &mut self.keys[i * rank + j - (i * (i + 1) / 2)]
182    }
183}
184
185impl<D: DataMut, DR: DataRef, B: Backend> Decompress<B, GGLWETensorKeyCompressed<DR>> for GGLWETensorKey<D>
186where
187    Module<B>: VecZnxFillUniform + VecZnxCopy,
188{
189    fn decompress(&mut self, module: &Module<B>, other: &GGLWETensorKeyCompressed<DR>) {
190        #[cfg(debug_assertions)]
191        {
192            assert_eq!(
193                self.keys.len(),
194                other.keys.len(),
195                "invalid receiver: self.keys.len()={} != other.keys.len()={}",
196                self.keys.len(),
197                other.keys.len()
198            );
199        }
200
201        self.keys
202            .iter_mut()
203            .zip(other.keys.iter())
204            .for_each(|(a, b)| {
205                a.decompress(module, b);
206            });
207    }
208}