poulpy_core/layouts/compressed/
gglwe_tsk.rs1use poulpy_hal::{
2 api::{VecZnxCopy, VecZnxFillUniform},
3 layouts::{Backend, Data, DataMut, DataRef, FillUniform, Module, ReaderFrom, WriterTo},
4 source::Source,
5};
6
7use crate::layouts::{
8 Base2K, Degree, Digits, GGLWELayoutInfos, GGLWETensorKey, GLWEInfos, LWEInfos, Rank, Rows, TorusPrecision,
9 compressed::{Decompress, GGLWESwitchingKeyCompressed},
10};
11use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
12use std::fmt;
13
14#[derive(PartialEq, Eq, Clone)]
15pub struct GGLWETensorKeyCompressed<D: Data> {
16 pub(crate) keys: Vec<GGLWESwitchingKeyCompressed<D>>,
17}
18
19impl<D: Data> LWEInfos for GGLWETensorKeyCompressed<D> {
20 fn n(&self) -> Degree {
21 self.keys[0].n()
22 }
23
24 fn base2k(&self) -> Base2K {
25 self.keys[0].base2k()
26 }
27
28 fn k(&self) -> TorusPrecision {
29 self.keys[0].k()
30 }
31 fn size(&self) -> usize {
32 self.keys[0].size()
33 }
34}
35impl<D: Data> GLWEInfos for GGLWETensorKeyCompressed<D> {
36 fn rank(&self) -> Rank {
37 self.rank_out()
38 }
39}
40
41impl<D: Data> GGLWELayoutInfos for GGLWETensorKeyCompressed<D> {
42 fn rank_in(&self) -> Rank {
43 self.rank_out()
44 }
45
46 fn rank_out(&self) -> Rank {
47 self.keys[0].rank_out()
48 }
49
50 fn digits(&self) -> Digits {
51 self.keys[0].digits()
52 }
53
54 fn rows(&self) -> Rows {
55 self.keys[0].rows()
56 }
57}
58
59impl<D: DataRef> fmt::Debug for GGLWETensorKeyCompressed<D> {
60 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
61 write!(f, "{self}")
62 }
63}
64
65impl<D: DataMut> FillUniform for GGLWETensorKeyCompressed<D> {
66 fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) {
67 self.keys
68 .iter_mut()
69 .for_each(|key: &mut GGLWESwitchingKeyCompressed<D>| key.fill_uniform(log_bound, source))
70 }
71}
72
73impl<D: DataRef> fmt::Display for GGLWETensorKeyCompressed<D> {
74 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
75 writeln!(f, "(GLWETensorKeyCompressed)",)?;
76 for (i, key) in self.keys.iter().enumerate() {
77 write!(f, "{i}: {key}")?;
78 }
79 Ok(())
80 }
81}
82
83impl GGLWETensorKeyCompressed<Vec<u8>> {
84 pub fn alloc<A>(infos: &A) -> Self
85 where
86 A: GGLWELayoutInfos,
87 {
88 assert_eq!(
89 infos.rank_in(),
90 infos.rank_out(),
91 "rank_in != rank_out is not supported for GGLWETensorKeyCompressed"
92 );
93 Self::alloc_with(
94 infos.n(),
95 infos.base2k(),
96 infos.k(),
97 infos.rows(),
98 infos.digits(),
99 infos.rank_out(),
100 )
101 }
102
103 pub fn alloc_with(n: Degree, base2k: Base2K, k: TorusPrecision, rows: Rows, digits: Digits, rank: Rank) -> Self {
104 let mut keys: Vec<GGLWESwitchingKeyCompressed<Vec<u8>>> = Vec::new();
105 let pairs: u32 = (((rank.0 + 1) * rank.0) >> 1).max(1);
106 (0..pairs).for_each(|_| {
107 keys.push(GGLWESwitchingKeyCompressed::alloc_with(
108 n,
109 base2k,
110 k,
111 rows,
112 digits,
113 Rank(1),
114 rank,
115 ));
116 });
117 Self { keys }
118 }
119
120 pub fn alloc_bytes<A>(infos: &A) -> usize
121 where
122 A: GGLWELayoutInfos,
123 {
124 assert_eq!(
125 infos.rank_in(),
126 infos.rank_out(),
127 "rank_in != rank_out is not supported for GGLWETensorKeyCompressed"
128 );
129 let rank_out: usize = infos.rank_out().into();
130 let pairs: usize = (((rank_out + 1) * rank_out) >> 1).max(1);
131 pairs
132 * GGLWESwitchingKeyCompressed::alloc_bytes_with(
133 infos.n(),
134 infos.base2k(),
135 infos.k(),
136 infos.rows(),
137 infos.digits(),
138 Rank(1),
139 infos.rank_out(),
140 )
141 }
142
143 pub fn alloc_bytes_with(n: Degree, base2k: Base2K, k: TorusPrecision, rows: Rows, digits: Digits, rank: Rank) -> usize {
144 let pairs: usize = (((rank.0 + 1) * rank.0) >> 1).max(1) as usize;
145 pairs * GGLWESwitchingKeyCompressed::alloc_bytes_with(n, base2k, k, rows, digits, Rank(1), rank)
146 }
147}
148
149impl<D: DataMut> ReaderFrom for GGLWETensorKeyCompressed<D> {
150 fn read_from<R: std::io::Read>(&mut self, reader: &mut R) -> std::io::Result<()> {
151 let len: usize = reader.read_u64::<LittleEndian>()? as usize;
152 if self.keys.len() != len {
153 return Err(std::io::Error::new(
154 std::io::ErrorKind::InvalidData,
155 format!("self.keys.len()={} != read len={}", self.keys.len(), len),
156 ));
157 }
158 for key in &mut self.keys {
159 key.read_from(reader)?;
160 }
161 Ok(())
162 }
163}
164
165impl<D: DataRef> WriterTo for GGLWETensorKeyCompressed<D> {
166 fn write_to<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<()> {
167 writer.write_u64::<LittleEndian>(self.keys.len() as u64)?;
168 for key in &self.keys {
169 key.write_to(writer)?;
170 }
171 Ok(())
172 }
173}
174
175impl<D: DataMut> GGLWETensorKeyCompressed<D> {
176 pub(crate) fn at_mut(&mut self, mut i: usize, mut j: usize) -> &mut GGLWESwitchingKeyCompressed<D> {
177 if i > j {
178 std::mem::swap(&mut i, &mut j);
179 };
180 let rank: usize = self.rank_out().into();
181 &mut self.keys[i * rank + j - (i * (i + 1) / 2)]
182 }
183}
184
185impl<D: DataMut, DR: DataRef, B: Backend> Decompress<B, GGLWETensorKeyCompressed<DR>> for GGLWETensorKey<D>
186where
187 Module<B>: VecZnxFillUniform + VecZnxCopy,
188{
189 fn decompress(&mut self, module: &Module<B>, other: &GGLWETensorKeyCompressed<DR>) {
190 #[cfg(debug_assertions)]
191 {
192 assert_eq!(
193 self.keys.len(),
194 other.keys.len(),
195 "invalid receiver: self.keys.len()={} != other.keys.len()={}",
196 self.keys.len(),
197 other.keys.len()
198 );
199 }
200
201 self.keys
202 .iter_mut()
203 .zip(other.keys.iter())
204 .for_each(|(a, b)| {
205 a.decompress(module, b);
206 });
207 }
208}