Skip to main content

falcon/
codec.rs

1//! Encoding/decoding for Falcon keys and signatures.
2//! Ported from codec.c.
3
4// ======================================================================
5// modq: 14-bit packed encoding for values mod q = 12289
6// ======================================================================
7
/// Packs `2^logn` integers modulo q = 12289 into consecutive 14-bit fields,
/// most-significant bit first.
///
/// Every coefficient must be strictly below 12289; otherwise 0 is returned.
/// When `out` is `None`, nothing is written and the required buffer size
/// (`ceil(14 * n / 8)` bytes) is returned. When `out` is shorter than that,
/// 0 is returned. Otherwise the packed bytes are written, the unused low bits
/// of the final byte are zeroed, and the number of bytes written is returned.
pub fn modq_encode(out: Option<&mut [u8]>, x: &[u16], logn: u32) -> usize {
    let n: usize = 1 << logn;
    if (0..n).any(|i| x[i] >= 12289) {
        return 0;
    }

    let needed = (14 * n + 7) / 8;
    let dst = match out {
        None => return needed,
        Some(buf) => buf,
    };
    if dst.len() < needed {
        return 0;
    }

    // `window` holds pending bits; only its low `pending` bits are unwritten.
    let mut window: u32 = 0;
    let mut pending: u32 = 0;
    let mut written = 0usize;
    for &coef in &x[..n] {
        window = (window << 14) | u32::from(coef);
        pending += 14;
        while pending >= 8 {
            pending -= 8;
            dst[written] = (window >> pending) as u8;
            written += 1;
        }
    }
    if pending != 0 {
        dst[written] = (window << (8 - pending)) as u8;
    }
    needed
}
45
/// Unpacks `2^logn` integers modulo q = 12289 from consecutive 14-bit fields,
/// most-significant bit first.
///
/// Exactly `ceil(14 * n / 8)` bytes are read from `input`; extra trailing
/// bytes are ignored. Returns that byte count on success, or 0 if `input` is
/// too short, if a decoded value is 12289 or more, or if the padding bits of
/// the last byte are not all zero. `x` may be partially written when 0 is
/// returned.
pub fn modq_decode(x: &mut [u16], logn: u32, input: &[u8]) -> usize {
    let n: usize = 1 << logn;
    let needed = (14 * n + 7) / 8;
    if input.len() < needed {
        return 0;
    }

    let mut window: u32 = 0;
    let mut pending: u32 = 0;
    let mut filled = 0usize;
    // Each byte adds 8 bits, so at most one 14-bit value completes per byte;
    // after exactly `needed` bytes all `n` values have been produced.
    for &byte in &input[..needed] {
        window = (window << 8) | u32::from(byte);
        pending += 8;
        if pending < 14 {
            continue;
        }
        pending -= 14;
        let value = (window >> pending) & 0x3FFF;
        if value >= 12289 {
            return 0;
        }
        x[filled] = value as u16;
        filled += 1;
    }

    let padding = window & ((1u32 << pending) - 1);
    if padding != 0 {
        return 0;
    }
    needed
}
77
78// ======================================================================
79// trim_i16: variable-width signed 16-bit encoding
80// ======================================================================
81
/// Encodes `2^logn` signed 16-bit integers on a fixed width of `bits` bits each.
///
/// Each value is written as the low `bits` bits of its two's-complement
/// representation, packed most-significant bit first into consecutive bytes;
/// the unused low bits of the final byte are set to zero.
///
/// Every value must lie in `-(2^(bits-1) - 1) ..= 2^(bits-1) - 1`. The most
/// negative `bits`-bit value, `-2^(bits-1)`, is rejected. `bits` must be in
/// `1..=16`.
///
/// Returns the number of bytes written, or 0 on error (invalid `bits`, a
/// value out of range, or an `out` buffer shorter than the required length).
/// If `out` is `None`, nothing is written and the required output length is
/// returned.
pub fn trim_i16_encode(out: Option<&mut [u8]>, x: &[i16], logn: u32, bits: u32) -> usize {
    // `bits == 0` would underflow `bits - 1`, and widths above 16 would
    // zero-extend negative i16 values (via `as u16`) instead of
    // sign-extending them, producing an encoding that does not round-trip.
    if bits == 0 || bits > 16 {
        return 0;
    }
    let n: usize = 1 << logn;
    let maxv = (1i32 << (bits - 1)) - 1;
    let minv = -maxv;
    for u in 0..n {
        let v = i32::from(x[u]);
        if v < minv || v > maxv {
            return 0;
        }
    }
    let out_len = ((n * bits as usize) + 7) >> 3;
    let buf = match out {
        None => return out_len,
        Some(b) => {
            if out_len > b.len() {
                return 0;
            }
            b
        }
    };
    let mut acc: u32 = 0;
    let mut acc_len: u32 = 0;
    let mask: u32 = (1u32 << bits) - 1;
    let mut pos = 0usize;
    for &v in &x[..n] {
        // `as u16` keeps the two's-complement bit pattern; the mask keeps
        // only the low `bits` bits of it.
        acc = (acc << bits) | (u32::from(v as u16) & mask);
        acc_len += bits;
        while acc_len >= 8 {
            acc_len -= 8;
            buf[pos] = (acc >> acc_len) as u8;
            pos += 1;
        }
    }
    if acc_len > 0 {
        buf[pos] = (acc << (8 - acc_len)) as u8;
    }
    out_len
}
122
/// Decodes `2^logn` signed 16-bit integers stored on `bits` bits each.
///
/// Values are read most-significant bit first, sign-extended from bit
/// `bits - 1`, and stored in `x`. Exactly `ceil(n * bits / 8)` bytes of
/// `input` are consumed; extra trailing bytes are ignored.
///
/// Returns the number of bytes consumed, or 0 on error: `bits` outside
/// `1..=16`, `input` too short, a decoded value equal to `-2^(bits-1)`
/// (forbidden), or non-zero padding bits in the last byte. `x` may be
/// partially written when 0 is returned.
pub fn trim_i16_decode(x: &mut [i16], logn: u32, bits: u32, input: &[u8]) -> usize {
    // `bits == 0` would underflow `bits - 1` below, and widths above 16
    // cannot be represented in an i16 (the `as i16` cast would truncate).
    if bits == 0 || bits > 16 {
        return 0;
    }
    let n: usize = 1 << logn;
    let in_len = ((n * bits as usize) + 7) >> 3;
    if in_len > input.len() {
        return 0;
    }
    let mut u: usize = 0;
    let mut acc: u32 = 0;
    let mut acc_len: u32 = 0;
    let mask1: u32 = (1u32 << bits) - 1;
    let mask2: u32 = 1u32 << (bits - 1);
    // Sign-extended 32-bit pattern of -2^(bits-1), the forbidden value.
    let forbidden: u32 = mask2.wrapping_neg();
    let mut buf_pos: usize = 0;
    while u < n {
        acc = (acc << 8) | u32::from(input[buf_pos]);
        buf_pos += 1;
        acc_len += 8;
        while acc_len >= bits && u < n {
            acc_len -= bits;
            let mut w: u32 = (acc >> acc_len) & mask1;
            // Sign-extend: if bit (bits - 1) is set, OR in all higher bits.
            w |= (w & mask2).wrapping_neg();
            if w == forbidden {
                return 0;
            }
            x[u] = w as i32 as i16;
            u += 1;
        }
    }
    if (acc & ((1u32 << acc_len) - 1)) != 0 {
        // Extra bits in the last byte must be zero.
        return 0;
    }
    in_len
}
160
161// ======================================================================
162// trim_i8: variable-width signed 8-bit encoding
163// ======================================================================
164
/// Encodes `2^logn` signed 8-bit integers on a fixed width of `bits` bits each.
///
/// Each value is written as the low `bits` bits of its two's-complement
/// representation, packed most-significant bit first into consecutive bytes;
/// the unused low bits of the final byte are set to zero.
///
/// Every value must lie in `-(2^(bits-1) - 1) ..= 2^(bits-1) - 1`. The most
/// negative `bits`-bit value, `-2^(bits-1)`, is rejected. `bits` must be in
/// `1..=8`.
///
/// Returns the number of bytes written, or 0 on error (invalid `bits`, a
/// value out of range, or an `out` buffer shorter than the required length).
/// If `out` is `None`, nothing is written and the required output length is
/// returned.
pub fn trim_i8_encode(out: Option<&mut [u8]>, x: &[i8], logn: u32, bits: u32) -> usize {
    // `bits == 0` would underflow `bits - 1`, and widths above 8 would
    // zero-extend negative i8 values (via `as u8`) instead of sign-extending
    // them, producing an encoding that does not round-trip.
    if bits == 0 || bits > 8 {
        return 0;
    }
    let n: usize = 1 << logn;
    let maxv = (1i32 << (bits - 1)) - 1;
    let minv = -maxv;
    for u in 0..n {
        let v = i32::from(x[u]);
        if v < minv || v > maxv {
            return 0;
        }
    }
    let out_len = ((n * bits as usize) + 7) >> 3;
    let buf = match out {
        None => return out_len,
        Some(b) => {
            if out_len > b.len() {
                return 0;
            }
            b
        }
    };
    let mut acc: u32 = 0;
    let mut acc_len: u32 = 0;
    let mask: u32 = (1u32 << bits) - 1;
    let mut pos = 0usize;
    for &v in &x[..n] {
        // `as u8` keeps the two's-complement bit pattern; the mask keeps only
        // the low `bits` bits of it.
        acc = (acc << bits) | (u32::from(v as u8) & mask);
        acc_len += bits;
        while acc_len >= 8 {
            acc_len -= 8;
            buf[pos] = (acc >> acc_len) as u8;
            pos += 1;
        }
    }
    if acc_len > 0 {
        buf[pos] = (acc << (8 - acc_len)) as u8;
    }
    out_len
}
205
/// Decodes `2^logn` signed 8-bit integers stored on `bits` bits each.
///
/// Values are read most-significant bit first, sign-extended from bit
/// `bits - 1`, and stored in `x`. Exactly `ceil(n * bits / 8)` bytes of
/// `input` are consumed; extra trailing bytes are ignored.
///
/// Returns the number of bytes consumed, or 0 on error: `bits` outside
/// `1..=8`, `input` too short, a decoded value equal to `-2^(bits-1)`
/// (forbidden), or non-zero padding bits in the last byte. `x` may be
/// partially written when 0 is returned.
pub fn trim_i8_decode(x: &mut [i8], logn: u32, bits: u32, input: &[u8]) -> usize {
    // `bits == 0` would underflow `bits - 1` below, and widths above 8
    // cannot be represented in an i8 (the `as i8` cast would truncate).
    if bits == 0 || bits > 8 {
        return 0;
    }
    let n: usize = 1 << logn;
    let in_len = ((n * bits as usize) + 7) >> 3;
    if in_len > input.len() {
        return 0;
    }
    let mut u: usize = 0;
    let mut acc: u32 = 0;
    let mut acc_len: u32 = 0;
    let mask1: u32 = (1u32 << bits) - 1;
    let mask2: u32 = 1u32 << (bits - 1);
    // Sign-extended 32-bit pattern of -2^(bits-1), the forbidden value.
    let forbidden: u32 = mask2.wrapping_neg();
    let mut buf_pos: usize = 0;
    while u < n {
        acc = (acc << 8) | u32::from(input[buf_pos]);
        buf_pos += 1;
        acc_len += 8;
        while acc_len >= bits && u < n {
            acc_len -= bits;
            let mut w: u32 = (acc >> acc_len) & mask1;
            // Sign-extend: if bit (bits - 1) is set, OR in all higher bits.
            w |= (w & mask2).wrapping_neg();
            if w == forbidden {
                return 0;
            }
            x[u] = w as i32 as i8;
            u += 1;
        }
    }
    if (acc & ((1u32 << acc_len) - 1)) != 0 {
        // Extra bits in the last byte must be zero.
        return 0;
    }
    in_len
}
242
243// ======================================================================
244// comp: variable-length compressed signature encoding
245// ======================================================================
246
/// Encodes signature coefficients with a variable-length compressed format.
///
/// Each of the `2^logn` coefficients is written as one sign bit (1 for a
/// negative value), then the low 7 bits of its absolute value, then the
/// remaining high part `k = |value| >> 7` in unary: `k` zero bits followed by
/// a single one bit. Bits are packed most-significant bit first, and the
/// unused low bits of the final byte are set to zero.
///
/// All values must be in `-2047..=2047`; otherwise 0 is returned.
///
/// Returns the number of bytes written, or 0 on error (a value out of range,
/// or `out` too small for the encoding). If `out` is `None`, nothing is
/// written and the encoded length is returned.
pub fn comp_encode(mut out: Option<&mut [u8]>, x: &[i16], logn: u32) -> usize {
    let n: usize = 1 << logn;

    // Verify values within range.
    for u in 0..n {
        if x[u] < -2047 || x[u] > 2047 {
            return 0;
        }
    }

    let mut acc: u32 = 0;
    let mut acc_len: u32 = 0;
    let mut v: usize = 0;
    for &coef in &x[..n] {
        // Push the sign bit.
        acc <<= 1;
        if coef < 0 {
            acc |= 1;
        }
        let mut w = u32::from(coef.unsigned_abs());

        // Push the low 7 bits of the absolute value: 8 bits so far.
        acc = (acc << 7) | (w & 127);
        w >>= 7;
        acc_len += 8;

        // Push `w` zeros then a one. Since |coef| <= 2047, w <= 15, so at
        // most 16 bits are added here; with the 8 bits above and at most 7
        // left over from earlier, no more than 31 bits are pending, which
        // fits in the u32 accumulator.
        acc = (acc << (w + 1)) | 1;
        acc_len += w + 1;

        // Produce all full bytes.
        while acc_len >= 8 {
            acc_len -= 8;
            if let Some(buf) = out.as_deref_mut() {
                match buf.get_mut(v) {
                    Some(slot) => *slot = (acc >> acc_len) as u8,
                    None => return 0,
                }
            }
            v += 1;
        }
    }

    // Flush remaining bits (if any), left-aligned with zero padding.
    if acc_len > 0 {
        if let Some(buf) = out.as_deref_mut() {
            match buf.get_mut(v) {
                Some(slot) => *slot = (acc << (8 - acc_len)) as u8,
                None => return 0,
            }
        }
        v += 1;
    }

    v
}
317
/// Decodes `2^logn` signature coefficients from the variable-length
/// compressed format (sign bit, low 7 magnitude bits, unary high part).
///
/// Returns the number of bytes consumed, or 0 on error: the input ends
/// before all coefficients are decoded, a magnitude exceeds 2047, a "-0" is
/// encountered, or the unused bits of the last consumed byte are not zero.
/// Bytes after the encoding are not consumed. `x` may be partially written
/// when 0 is returned.
pub fn comp_decode(x: &mut [i16], logn: u32, input: &[u8]) -> usize {
    let n: usize = 1 << logn;
    let mut window: u32 = 0;
    let mut pending: u32 = 0;
    let mut consumed: usize = 0;

    for slot in 0..n {
        // One fresh byte supplies the sign bit and low seven magnitude bits.
        let first = match input.get(consumed) {
            Some(&b) => b,
            None => return 0,
        };
        consumed += 1;
        window = (window << 8) | u32::from(first);
        let head = window >> pending;
        let negative = (head & 0x80) != 0;
        let mut magnitude = head & 0x7F;

        // Unary high part: each 0 bit adds 128, a 1 bit terminates.
        loop {
            if pending == 0 {
                let next = match input.get(consumed) {
                    Some(&b) => b,
                    None => return 0,
                };
                consumed += 1;
                window = (window << 8) | u32::from(next);
                pending = 8;
            }
            pending -= 1;
            if ((window >> pending) & 1) == 1 {
                break;
            }
            magnitude += 128;
            if magnitude > 2047 {
                return 0;
            }
        }

        // A negative zero is not a canonical encoding.
        if negative && magnitude == 0 {
            return 0;
        }

        let value = magnitude as i32;
        x[slot] = (if negative { -value } else { value }) as i16;
    }

    // Leftover low bits of the final byte must all be zero.
    if (window & ((1u32 << pending) - 1)) != 0 {
        return 0;
    }
    consumed
}
372
373// ======================================================================
374// Bit-width limits for key/signature elements (indexed by logn, 0..10)
375// ======================================================================
376
/// Per-`logn` upper bound on the bit width of each `f` / `g` coefficient.
/// Index 0 is unused.
pub static MAX_FG_BITS: [u8; 11] = [
    0, // logn = 0 (unused)
    8, 8, 8, 8, 8, // logn = 1..=5
    7, 7, // logn = 6..=7
    6, 6, // logn = 8..=9
    5, // logn = 10
];

/// Per-`logn` upper bound on the bit width of each `F` / `G` coefficient.
/// Index 0 is unused.
pub static MAX_FG_BITS_UPPER: [u8; 11] = [
    0, // logn = 0 (unused)
    8, 8, 8, 8, 8, 8, 8, 8, 8, 8, // logn = 1..=10
];

/// Per-`logn` upper bound on the bit width of each signature coefficient,
/// sign bit included. Index 0 is unused.
pub static MAX_SIG_BITS: [u8; 11] = [
    0,  // logn = 0 (unused)
    10, // logn = 1
    11, 11, // logn = 2..=3
    12, 12, 12, 12, 12, 12, 12, // logn = 4..=10
];
389/// Maximum number of bits for signature coefficients (including sign bit).
390pub static MAX_SIG_BITS: [u8; 11] = [
391    0, // unused
392    10, 11, 11, 12, 12, 12, 12, 12, 12, 12,
393];