fn_dsa_comm/
codec.rs

/// Pack signed 8-bit values into bytes using a fixed bit width per value.
///
/// Every element of `f` contributes its `nbits` low-order bits (two's
/// complement), most significant bit first, to the output buffer `d`. When
/// the total number of bits is not a multiple of 8, the last byte is
/// completed with zero bits. The number of bytes written is returned.
///
/// Bits of a value above the low `nbits` are masked away without any check,
/// so callers are expected to supply values that fit in `nbits` signed bits.
///
/// # Panics
///
/// Panics if `d` is too short to hold the encoded output.
pub fn trim_i8_encode(f: &[i8], nbits: u32, d: &mut [u8]) -> usize {
    let field_mask = (1u32 << nbits) - 1;
    let mut pos = 0usize;
    let mut bits: u32 = 0;
    let mut pending: u32 = 0;
    for &value in f {
        // Append the low `nbits` bits of the value to the bit accumulator.
        bits = (bits << nbits) | (u32::from(value as u8) & field_mask);
        pending += nbits;
        // Emit every complete byte now available.
        while pending >= 8 {
            pending -= 8;
            d[pos] = (bits >> pending) as u8;
            pos += 1;
        }
    }
    // Left-align the leftover bits in one final, zero-padded byte.
    if pending != 0 {
        d[pos] = (bits << (8 - pending)) as u8;
        pos += 1;
    }
    pos
}
27
/// Unpack fixed-width signed integers from bytes.
///
/// Each element of `f` is read as an `nbits`-bit two's-complement value,
/// most significant bit first, from the start of `d`. Only the minimal
/// number of bytes, `ceil(f.len() * nbits / 8)`, is consumed. On success
/// that byte count is returned. `None` is returned when:
///
///  - `d` holds fewer bytes than required;
///  - a decoded value equals `-2^(nbits-1)`, which is not a valid encoding;
///  - the padding bits following the last value (within the last consumed
///    byte) are not all zero.
///
/// On failure, `f` may have been partially overwritten.
///
/// The number of bits per coefficient (`nbits`) MUST lie between 2 and 8
/// (inclusive).
pub fn trim_i8_decode(d: &[u8], f: &mut [i8], nbits: u32) -> Option<usize> {
    let count = f.len();
    let total = (count * (nbits as usize) + 7) / 8;
    let src = d.get(..total)?;

    let field_mask: u32 = (1u32 << nbits) - 1;
    let sign_bit: u32 = 1u32 << (nbits - 1);
    // All bits from the sign position upwards: used both to sign-extend and
    // as the 32-bit image of the forbidden value -2^(nbits-1).
    let high_ones: u32 = sign_bit.wrapping_neg();

    let mut window: u32 = 0;
    let mut pending: u32 = 0;
    let mut filled = 0usize;
    for &byte in src {
        window = (window << 8) | u32::from(byte);
        pending += 8;
        while pending >= nbits {
            pending -= nbits;
            let raw = (window >> pending) & field_mask;
            let value = if (raw & sign_bit) == 0 {
                raw
            } else {
                raw | high_ones
            };
            if value == high_ones {
                return None;
            }
            f[filled] = value as i8;
            filled += 1;
            if filled >= count {
                break;
            }
        }
    }

    // Whatever remains below the last value must be zero padding.
    let leftover = window & ((1u32 << pending) - 1);
    if leftover != 0 {
        return None;
    }
    Some(total)
}
76
/// Pack integers modulo q = 12289 into bytes, 14 bits per value.
///
/// The values of `h` are taken four at a time. Each group of four is written
/// to `d` as a 56-bit big-endian integer (7 bytes), with the first value of
/// the group in the most significant 14 bits. The values MUST be in the
/// `[0,q-1]` range (this is not checked) and `h.len()` MUST be a multiple
/// of 4. Returns the number of bytes written, i.e. `7 * h.len() / 4`.
///
/// # Panics
///
/// Panics if `h.len()` is not a multiple of 4, or if `d` is shorter than
/// `7 * h.len() / 4` bytes.
pub fn modq_encode(h: &[u16], d: &mut [u8]) -> usize {
    assert!((h.len() & 3) == 0);
    let mut written = 0;
    for group in h.chunks_exact(4) {
        // Concatenate the four 14-bit fields into the low bits of a u64.
        let packed = group
            .iter()
            .fold(0u64, |acc, &value| (acc << 14) | u64::from(value));
        // Only the low 56 bits (the last 7 big-endian bytes) are emitted.
        let bytes = packed.to_be_bytes();
        d[written..written + 7].copy_from_slice(&bytes[1..]);
        written += 7;
    }
    written
}
97
/// Decode integers modulo q = 12289 from bytes, with 14 bits per value.
///
/// The source `d` is read as consecutive 7-byte groups. Each group is a
/// 56-bit big-endian integer holding four 14-bit values, most significant
/// first (the layout written by `modq_encode`). On success, every element
/// of `h` is filled and the number of bytes read, `7 * h.len() / 4`, is
/// returned.
///
/// `None` is returned if:
///
///  - the length of `d` is not exactly `7 * h.len() / 4`: a source that is
///    too short AND a source with trailing bytes are both rejected; or
///  - any decoded value is not in the `[0,q-1]` range.
///
/// If `h` is empty, `Some(0)` is returned regardless of `d`. On failure,
/// `h` may have been partially or fully overwritten.
///
/// # Panics
///
/// Panics if `h.len()` is not a multiple of 4.
pub fn modq_decode(d: &[u8], h: &mut [u16]) -> Option<usize> {
    let n = h.len();
    if n == 0 {
        return Some(0);
    }
    assert!((n & 3) == 0);
    let needed = 7 * (n >> 2);
    if d.len() != needed {
        return None;
    }

    // Range check without early exit: for v < q, v - q wraps to 2^32 - k
    // with 1 <= k <= q, whose bit 15 is set; for q <= v < 2^14, v - q lies
    // in [0, 4094] and has bit 15 clear. So bit 15 of `ov` survives the AND
    // of all values iff every value is in range.
    let mut ov: u32 = 0xFFFF;
    for (src, dst) in d.chunks_exact(7).zip(h.chunks_exact_mut(4)) {
        // Zero-extend the 7-byte group to 8 bytes so it reads as a u64.
        let mut buf = [0u8; 8];
        buf[1..].copy_from_slice(src);
        let x = u64::from_be_bytes(buf);
        for (k, slot) in dst.iter_mut().enumerate() {
            let v = ((x >> (42 - 14 * k)) as u32) & 0x3FFF;
            ov &= v.wrapping_sub(12289);
            *slot = v as u16;
        }
    }
    if (ov & 0x8000) == 0 {
        return None;
    }
    Some(needed)
}
157
/// Compress small signed integers into bytes (Golomb-Rice style).
///
/// Each value `x` of `s` is emitted as:
///
///  - one sign bit (1 for negative values);
///  - the low 7 bits of `|x|`;
///  - the high bits `|x| >> 7` in unary: that many 0 bits, then a single 1.
///
/// Bits are packed most significant first. After the last value, any partial
/// byte is completed with zero bits and the remainder of `d` is zero-filled.
///
/// Returns `false` if some value lies outside `[-2047, +2047]` or if `d` is
/// too small for the encoding (in both cases `d` may be partially written).
/// Returns `true` otherwise, with every byte of `d` set.
pub fn comp_encode(s: &[i16], d: &mut [u8]) -> bool {
    let mut bits: u32 = 0;
    let mut pending: u32 = 0;
    let mut pos = 0usize;
    for &sample in s {
        // Here `pending` is at most 7.
        let x = i32::from(sample);
        if !(-2047..=2047).contains(&x) {
            return false;
        }
        let magnitude = x.unsigned_abs();
        let sign: u32 = if x < 0 { 0x80 } else { 0 };

        // Sign bit followed by the 7 low bits of the magnitude.
        bits = (bits << 8) | sign | (magnitude & 0x7F);

        // High part of the magnitude (at most 15) as zeros ended by a 1.
        let high = magnitude >> 7;
        bits = (bits << (high + 1)) | 1;
        pending += 9 + high;

        // pending <= 7 + 24 = 31, so all live bits fit in `bits`.
        while pending >= 8 {
            pending -= 8;
            match d.get_mut(pos) {
                Some(slot) => *slot = (bits >> pending) as u8,
                None => return false,
            }
            pos += 1;
        }
    }

    // Emit the trailing partial byte, left-aligned.
    if pending > 0 {
        match d.get_mut(pos) {
            Some(slot) => *slot = (bits << (8 - pending)) as u8,
            None => return false,
        }
        pos += 1;
    }

    // Zero the unused tail of the output buffer.
    d[pos..].fill(0);
    true
}
223
/// Decode small integers from bytes using a compressed (Golomb-Rice) format.
///
/// Decode the provided source buffer `d` into signed integers `v`, using
/// the compressed encoding convention written by `comp_encode`. For each
/// value, the encoding has three parts:
///
///  - one sign bit;
///  - the 7 low bits of the absolute value;
///  - the high bits of the absolute value in unary: as many 0 bits as their
///    numeric value, then a single 1 bit.
///
/// This function returns `false` in any of the following cases:
///
///  - Source does not contain enough encoded integers to fill `v` entirely.
///  - An invalid encoding for a value is encountered (absolute value above
///    2047, or a negative sign on zero).
///  - Any of the remaining unused bits in `d` (after all integers have been
///    decoded) is non-zero.
///
/// Otherwise it returns `true`. On failure, `v` may have been partially
/// overwritten.
///
/// Valid encodings cover exactly the integers in the `[-2047,+2047]` range.
/// For a given sequence of integers, there is only one valid encoding as
/// a sequence of bytes (of a given length).
pub fn comp_decode(d: &[u8], v: &mut [i16]) -> bool {
    let mut i = 0usize;
    let mut acc: u32 = 0;
    let mut acc_len: u32 = 0;
    for out in v.iter_mut() {
        // Invariant: acc_len <= 7 at the beginning of each iteration; the
        // low acc_len bits of `acc` are not yet consumed.

        // Get next 8 bits and split them into sign bit (s) and low bits
        // of the absolute value (m).
        if i >= d.len() {
            return false;
        }
        acc = (acc << 8) | u32::from(d[i]);
        i += 1;
        let s = (acc >> (acc_len + 7)) & 1;
        let mut m = (acc >> acc_len) & 0x7F;

        // Read the unary high part: each 0 bit adds 128, a 1 bit ends it.
        loop {
            if acc_len == 0 {
                if i >= d.len() {
                    return false;
                }
                acc = (acc << 8) | u32::from(d[i]);
                i += 1;
                acc_len = 8;
            }
            acc_len -= 1;
            if ((acc >> acc_len) & 1) != 0 {
                break;
            }
            m += 0x80;
            if m > 2047 {
                return false;
            }
        }

        // Reject "-0" (invalid encoding). Since m <= 2047, the expression
        // `m.wrapping_sub(1) >> 31` is 1 exactly when m == 0.
        if (s & (m.wrapping_sub(1) >> 31)) != 0 {
            return false;
        }

        // Apply the sign (two's-complement negation when s == 1).
        let sw = s.wrapping_neg();
        let w = (m ^ sw).wrapping_sub(sw);
        *out = w as i16;
    }

    // Unused low bits of the last partially consumed byte must be zero.
    if acc_len > 0 && (acc & ((1u32 << acc_len) - 1)) != 0 {
        return false;
    }
    // All bytes after the last consumed one must be zero as well.
    d[i..].iter().all(|&b| b == 0)
}