fn_dsa_comm/codec.rs
/// Encode small signed integers into bytes, using a fixed bit width.
///
/// Each value of `f` is reinterpreted as an unsigned byte and truncated
/// to its low `nbits` bits (two's complement). The resulting fields are
/// concatenated in big-endian bit order into `d`. If the total bit count
/// is not a multiple of 8, the last written byte is completed with zero
/// bits.
///
/// Returns the number of bytes written into `d`.
///
/// # Panics
///
/// Panics if `d` is too short to hold `ceil(f.len() * nbits / 8)` bytes.
pub fn trim_i8_encode(f: &[i8], nbits: u32, d: &mut [u8]) -> usize {
    let field_mask: u32 = (1u32 << nbits) - 1;
    let mut bits: u32 = 0;
    let mut pending: u32 = 0;
    let mut written: usize = 0;

    for &value in f {
        // Reinterpret as unsigned, then keep only the low `nbits` bits.
        let field = u32::from(value as u8) & field_mask;
        bits = (bits << nbits) | field;
        pending += nbits;

        // Emit every complete byte, most significant bits first.
        while pending >= 8 {
            pending -= 8;
            d[written] = (bits >> pending) as u8;
            written += 1;
        }
    }

    if pending != 0 {
        // Left-align the leftover bits; the vacated low bits are the
        // zero padding.
        d[written] = (bits << (8 - pending)) as u8;
        written += 1;
    }
    written
}
27
/// Decode small signed integers from bytes, using a fixed bit width.
///
/// Reads `nbits`-bit two's-complement fields (big-endian bit order) from
/// `d` until every slot of `f` has been filled. Only the first
/// `ceil(f.len() * nbits / 8)` bytes of `d` are consumed, and that byte
/// count is returned on success; extra trailing bytes are ignored.
/// `None` is returned when any of the following happens:
///
/// - `d` holds fewer bytes than required.
/// - A field decodes to `-2^(nbits-1)`, which is not an accepted encoding.
/// - The unused bits at the end of the last consumed byte are not all zero.
///
/// When `None` is returned, `f` may have been partially overwritten.
///
/// The number of bits per coefficient (`nbits`) MUST lie between 2 and 8
/// (inclusive).
pub fn trim_i8_decode(d: &[u8], f: &mut [i8], nbits: u32) -> Option<usize> {
    let count = f.len();
    let needed = (count * nbits as usize + 7) / 8;
    let src = d.get(..needed)?;

    let field_mask: u32 = (1u32 << nbits) - 1;
    let sign_bit: u32 = 1u32 << (nbits - 1);
    // Sign-extended bit pattern of -2^(nbits-1), the rejected value.
    let forbidden = sign_bit.wrapping_neg();

    let mut bits: u32 = 0;
    let mut pending: u32 = 0;
    let mut filled = 0usize;
    for &byte in src {
        bits = (bits << 8) | u32::from(byte);
        pending += 8;
        while pending >= nbits && filled < count {
            pending -= nbits;
            let raw = (bits >> pending) & field_mask;
            // Branch-free sign extension: when the field's top bit is set,
            // OR in every bit above it.
            let value = raw | (raw & sign_bit).wrapping_neg();
            if value == forbidden {
                return None;
            }
            f[filled] = value as i8;
            filled += 1;
        }
    }

    // Whatever remains in the accumulator is padding and must be zero.
    if (bits & ((1u32 << pending) - 1)) != 0 {
        return None;
    }
    Some(needed)
}
76
/// Encode integers modulo q = 12289 into bytes, with 14 bits per value.
///
/// The values of `h` are processed in groups of four. Each group is
/// packed into a 56-bit big-endian field, which is written as 7 bytes
/// into `d`. The values MUST be in the `[0,q-1]` range; this is not
/// checked here.
///
/// Returns the number of bytes written, i.e. `7 * h.len() / 4`.
///
/// # Panics
///
/// Panics if `h.len()` is not a multiple of 4, or if `d` is too short.
pub fn modq_encode(h: &[u16], d: &mut [u8]) -> usize {
    assert!((h.len() & 3) == 0);
    let mut written = 0;
    for quad in h.chunks_exact(4) {
        // Shift each value in below the previous ones:
        // h0 << 42 | h1 << 28 | h2 << 14 | h3.
        let packed = quad
            .iter()
            .fold(0u64, |acc, &v| (acc << 14) | u64::from(v));
        let bytes = packed.to_be_bytes();
        // The top byte of the 64-bit word is always zero; skip it.
        d[written..(written + 7)].copy_from_slice(&bytes[1..]);
        written += 7;
    }
    written
}
97
/// Decode integers modulo q = 12289 from bytes, with 14 bits per value.
///
/// Decode some bytes into integers modulo q = 12289. Exactly as many
/// bytes as necessary (`7 * h.len() / 4`) are read from the source `d`
/// to fill all values in the destination slice `h`. Any bytes of `d`
/// beyond that point are ignored. The total number of read bytes is
/// returned. If the source is too short, or if any of the decoded values
/// is invalid (i.e. not in the `[0,q-1]` range), then this function
/// returns `None`. When `None` is returned because of an invalid value,
/// `h` has nevertheless been overwritten.
///
/// # Panics
///
/// Panics if `h.len()` is non-zero and not a multiple of 4.
pub fn modq_decode(d: &[u8], h: &mut [u16]) -> Option<usize> {
    let n = h.len();
    if n == 0 {
        return Some(0);
    }
    assert!((n & 3) == 0);
    let needed = 7 * (n >> 2);
    // Only a too-short source is an error. Trailing bytes are left for the
    // caller, who can skip past them using the returned count.
    if d.len() < needed {
        return None;
    }

    // Validity accumulator. For v in [0, 12288], v - 12289 wraps to a
    // 32-bit value whose bit 15 is set. For v in [12289, 16383] the
    // difference is below 2^15, so bit 15 is clear. ANDing all the
    // differences keeps bit 15 set iff every value is valid, without a
    // per-value branch.
    let mut ov: u32 = 0xFFFF;
    for (src, dst) in d[..needed].chunks_exact(7).zip(h.chunks_exact_mut(4)) {
        // Load the 7 bytes as a big-endian 56-bit integer.
        let mut buf = [0u8; 8];
        buf[1..].copy_from_slice(src);
        let x = u64::from_be_bytes(buf);
        for (k, slot) in dst.iter_mut().enumerate() {
            // Fields sit at bit offsets 42, 28, 14 and 0.
            let v = ((x >> (42 - 14 * k)) as u32) & 0x3FFF;
            ov &= v.wrapping_sub(12289);
            *slot = v as u16;
        }
    }
    if (ov & 0x8000) == 0 {
        return None;
    }
    Some(needed)
}
157
/// Encode small integers into bytes using a compressed (Golomb-Rice) format.
///
/// Each value of `s` is emitted as a sign bit, followed by the 7 low bits
/// of its absolute value. The remaining high bits `|x| >> 7` then follow
/// in unary form: that many `0` bits and a terminating `1`. Bits are
/// packed in big-endian order.
///
/// Returns `false` if any source value is outside `[-2047,+2047]`, or if
/// `d` is too small for the encoding. In both cases `d` may already have
/// been partially written. Otherwise, all remaining bits of the last byte
/// and all remaining bytes of `d` are set to zero, and `true` is returned.
pub fn comp_encode(s: &[i16], d: &mut [u8]) -> bool {
    let mut bitbuf: u32 = 0;
    let mut nbits: u32 = 0;
    let mut pos: usize = 0;

    for &coef in s {
        // At this point `nbits` is at most 7.
        let x = i32::from(coef);
        if !(-2047..=2047).contains(&x) {
            return false;
        }

        // All-ones mask when `x` is negative, zero otherwise. The
        // conditional negation then yields |x| without branching.
        let neg_mask = (x >> 16) as u32;
        let magnitude = ((x as u32) ^ neg_mask).wrapping_sub(neg_mask);

        // Head byte: sign bit plus the 7 low bits of |x|. Tail: `high`
        // zero bits and a final one (high <= 15 since |x| <= 2047).
        let head = (neg_mask & 0x80) | (magnitude & 0x7F);
        let high = magnitude >> 7;
        bitbuf = (((bitbuf << 8) | head) << (high + 1)) | 1;
        nbits += 8 + high + 1;

        // At most 7 + 24 = 31 bits are buffered, so nothing is lost from
        // the 32-bit accumulator before the full bytes are emitted.
        while nbits >= 8 {
            nbits -= 8;
            match d.get_mut(pos) {
                Some(slot) => *slot = (bitbuf >> nbits) as u8,
                None => return false,
            }
            pos += 1;
        }
    }

    // Flush a trailing partial byte, left-aligned with zero padding.
    if nbits > 0 {
        match d.get_mut(pos) {
            Some(slot) => *slot = (bitbuf << (8 - nbits)) as u8,
            None => return false,
        }
        pos += 1;
    }

    // Zero every unused output byte.
    d[pos..].fill(0);
    true
}
223
/// Decode small integers from bytes using a compressed (Golomb-Rice) format.
///
/// Decode the provided source buffer `d` into signed integers `v`, using
/// the compressed encoding convention. Each value is read as a sign bit,
/// then the 7 low bits of its absolute value, then the high bits in
/// unary form: each `0` adds 128 to the absolute value, and a `1` ends
/// the value. This function returns `false` in any of the following
/// cases:
///
/// - Source does not contain enough encoded integers to fill `v` entirely.
/// - An invalid encoding for a value is encountered: an absolute value
///   above 2047, or a "negative zero".
/// - Any of the remaining unused bits in `d` (after all integers have been
///   decoded) is non-zero.
///
/// When `false` is returned, `v` may have been partially overwritten.
///
/// Valid encodings cover exactly the integers in the `[-2047,+2047]` range.
/// For a given sequence of integers, there is only one valid encoding as
/// a sequence of bytes (of a given length).
pub fn comp_decode(d: &[u8], v: &mut [i16]) -> bool {
    let mut pos = 0usize;
    let mut bitbuf: u32 = 0;
    let mut avail: u32 = 0;

    for out in v.iter_mut() {
        // `avail` (at most 7) buffered bits are still unread. Loading one
        // more byte exposes the next 8 stream bits at positions avail..avail+7.
        let byte = match d.get(pos) {
            Some(&b) => b,
            None => return false,
        };
        pos += 1;
        bitbuf = (bitbuf << 8) | u32::from(byte);
        let sign = (bitbuf >> (avail + 7)) & 1;
        let mut magnitude = (bitbuf >> avail) & 0x7F;

        // Unary part: each 0 bit adds 128 to the magnitude, a 1 ends it.
        loop {
            if avail == 0 {
                let byte = match d.get(pos) {
                    Some(&b) => b,
                    None => return false,
                };
                pos += 1;
                bitbuf = (bitbuf << 8) | u32::from(byte);
                avail = 8;
            }
            avail -= 1;
            if ((bitbuf >> avail) & 1) == 1 {
                break;
            }
            magnitude += 0x80;
            if magnitude > 2047 {
                return false;
            }
        }

        // "-0" is not a valid encoding. `is_zero` is 1 exactly when
        // magnitude == 0 (magnitude never exceeds 2047 here).
        let is_zero = magnitude.wrapping_sub(1) >> 31;
        if (sign & is_zero) != 0 {
            return false;
        }

        // Conditionally negate without branching on the sign.
        let neg_mask = sign.wrapping_neg();
        *out = ((magnitude ^ neg_mask).wrapping_sub(neg_mask)) as i16;
    }

    // Leftover bits of the last consumed byte must be zero ...
    if (bitbuf & ((1u32 << avail) - 1)) != 0 {
        return false;
    }
    // ... and so must every byte that was not consumed.
    d[pos..].iter().all(|&b| b == 0)
}