use bit_vec::BitVec;
use falcon_profiler::profiling;
use itertools::Itertools;
use num::Integer;
#[profiling]
pub(crate) fn compress(v: &[i16], byte_length: usize) -> Option<Vec<u8>> {
let lengths_and_coefficients = v.iter().map(|c| compress_coefficient(*c)).collect_vec();
let total_length = lengths_and_coefficients
.iter()
.map(|(l, _c)| *l)
.sum::<usize>();
if total_length > byte_length * 8 {
return None;
}
if v.is_empty() {
return None;
}
let mut bytes = vec![0u8; byte_length];
let mut counter = 0;
for (length, coefficient) in lengths_and_coefficients.iter().take(v.len() - 1) {
let (cdiv8, cmod8) = counter.div_mod_floor(&8);
bytes[cdiv8] |= coefficient >> cmod8;
bytes[cdiv8 + 1] |= ((*coefficient as u16) << (8 - cmod8)) as u8;
let (cldiv8, clmod8) = (counter + length - 1).div_mod_floor(&8);
bytes[cldiv8] |= 128u8 >> clmod8;
bytes[cldiv8 + 1] |= (128u16 << (8 - clmod8)) as u8;
counter += length;
}
let (length, coefficient) = lengths_and_coefficients.last().unwrap();
{
let (cdiv8, cmod8) = counter.div_mod_floor(&8);
bytes[cdiv8] |= coefficient >> cmod8;
bytes[cdiv8 + 1] |= ((*coefficient as u16) << (8 - cmod8)) as u8;
let (cldiv8, clmod8) = (counter + length - 1).div_mod_floor(&8);
bytes[cldiv8] |= 128u8 >> clmod8;
if cldiv8 + 1 < byte_length {
bytes[cldiv8 + 1] |= (128u16 << (8 - clmod8)) as u8;
} else if (128u16 << (8 - clmod8)) as u8 != 0 {
return None;
}
counter += length;
}
Some(bytes)
}
/// Splits one coefficient into its encoded bit length and its leading byte.
///
/// The leading byte holds the sign in bit 7 and the 7 low bits of the magnitude.
/// The length counts that byte (8 bits), one zero bit per multiple of 128 in the
/// magnitude, and the final terminating one bit.
fn compress_coefficient(coeff: i16) -> (usize, u8) {
    let magnitude = coeff.unsigned_abs();
    let sign_bit: u8 = if coeff < 0 { 0x80 } else { 0 };
    let low_seven = (magnitude & 0x7F) as u8;
    let unary_zeros = usize::from(magnitude >> 7);
    // 1 sign bit + 7 low bits + `unary_zeros` zero bits + 1 terminating one bit.
    (unary_zeros + 9, sign_bit | low_seven)
}
/// Bit-by-bit reference decoder for the encoding produced by [`compress`].
///
/// It is meant to agree with [`decompress`] on every input, valid or not.
///
/// Reads `n` coefficients from `x`, MSB-first. Each coefficient is a sign bit, the 7 low
/// bits of its magnitude, one zero bit per multiple of 128, and a terminating one bit.
///
/// Returns `None` (and never panics) in any of these cases:
/// * the input ends before `n` coefficients are read;
/// * a magnitude exceeds 2047;
/// * a coefficient is encoded as negative zero;
/// * any bit after the last coefficient is set.
#[allow(dead_code)]
pub(crate) fn decompress_slow(x: &[u8], n: usize) -> Option<Vec<i16>> {
    let bitvector = BitVec::from_bytes(x);
    let mut index = 0;
    let mut result = Vec::with_capacity(n);
    for _ in 0..n {
        // Need at least the sign bit, 7 low bits and one unary bit.
        if index + 8 >= bitvector.len() {
            return None;
        }
        let negative = bitvector[index];
        index += 1;
        let mut low_bits = 0i16;
        for _ in 0..7 {
            low_bits = (low_bits << 1) | i16::from(bitvector[index]);
            index += 1;
        }
        let mut high_bits = 0i16;
        loop {
            match bitvector.get(index) {
                // Ran out of input before the terminating one bit.
                None => return None,
                Some(true) => break,
                Some(false) => {
                    high_bits += 1;
                    // 16 or more zeros means a magnitude above 2047.
                    if high_bits > 15 {
                        return None;
                    }
                    index += 1;
                }
            }
        }
        // Skip the terminating one bit.
        index += 1;
        let magnitude = (high_bits << 7) | low_bits;
        // Zero must be encoded with a cleared sign bit.
        if negative && magnitude == 0 {
            return None;
        }
        result.push(if negative { -magnitude } else { magnitude });
    }
    // All padding after the last coefficient must be zero.
    if bitvector.iter().skip(index).any(|bit| bit) {
        return None;
    }
    Some(result)
}
/// Decodes exactly `n` coefficients from the compressed encoding in `x`.
///
/// Bits are consumed MSB-first. Each coefficient is a sign bit, the 7 low bits of its
/// magnitude, one zero bit per multiple of 128, and a terminating one bit.
///
/// Returns `None` in any of these cases:
/// * `x` ends before `n` coefficients have been decoded;
/// * a magnitude would exceed 2047;
/// * a coefficient is encoded as negative zero;
/// * any bit after the last coefficient is non-zero, including whole trailing bytes.
#[profiling]
pub(crate) fn decompress(x: &[u8], n: usize) -> Option<Vec<i16>> {
    let mut coefficients = Vec::with_capacity(n);
    let mut input = x.iter().copied();
    // `window` holds recently read bits; its low `unread` bits have not been consumed yet.
    let mut window: u32 = 0;
    let mut unread: u32 = 0;

    for _ in 0..n {
        // Fewer than 8 bits are ever left over here, so sign + low bits need a fresh byte.
        window = (window << 8) | u32::from(input.next()?);
        let negative = (window >> (unread + 7)) & 1 == 1;
        let mut magnitude = (window >> unread) & 0x7F;

        // Unary part: each zero bit adds 128, a one bit ends the coefficient.
        loop {
            if unread == 0 {
                window = (window << 8) | u32::from(input.next()?);
                unread = 8;
            }
            unread -= 1;
            if (window >> unread) & 1 == 1 {
                break;
            }
            magnitude += 0x80;
            if magnitude > 2047 {
                return None;
            }
        }

        // Zero must be encoded with a cleared sign bit.
        if negative && magnitude == 0 {
            return None;
        }
        let value = magnitude as i16;
        coefficients.push(if negative { -value } else { value });
    }

    // Padding after the last coefficient must be zero: first the partially read byte...
    if window & ((1u32 << unread) - 1) != 0 {
        return None;
    }
    // ...then every remaining whole byte.
    if input.any(|byte| byte != 0) {
        return None;
    }
    Some(coefficients)
}
// Round-trip, equivalence and rejection tests for the compression codec.
#[cfg(test)]
mod test {
    use crate::encoding::{compress, decompress, decompress_slow};
    use crate::falcon_field::Q;
    use bit_vec::BitVec;
    use itertools::Itertools;
    use rand::distr::Distribution;
    use rand::{rng, RngExt};
    use proptest::prelude::*;

    /// Bit-level reference encoder.
    ///
    /// Unlike `compress`, `slen` here is a length in **bits**; the output is zero-padded to
    /// `slen` bits. Returns `None` if the encoding is longer than `slen` bits.
    #[allow(dead_code)]
    pub(crate) fn compress_slow(v: &[i16], slen: usize) -> Option<Vec<u8>> {
        let mut bitvector: BitVec = BitVec::with_capacity(slen);
        for coeff in v {
            bitvector.push(*coeff < 0);
            let s = (*coeff).abs();
            for i in (0..7).rev() {
                bitvector.push(((s >> i) & 1) != 0);
            }
            for _ in 0..(s >> 7) {
                bitvector.push(false);
            }
            bitvector.push(true);
        }
        if bitvector.len() > slen {
            return None;
        }
        while bitvector.len() < slen {
            bitvector.push(false);
        }
        Some(bitvector.to_bytes())
    }

    /// Samples `n` rounded Gaussian coefficients with sigma = 1.5 * sqrt(Q).
    fn short_elements(n: usize) -> Vec<i16> {
        let sigma = 1.5 * f64::from(Q).sqrt();
        let distribution = rand_distr::Normal::<f64>::new(0.0, sigma).unwrap();
        let mut rng = rng();
        (0..n)
            .map(|_| (distribution.sample(&mut rng) + 0.5).floor() as i16)
            .collect::<Vec<_>>()
    }

    proptest! {
        #[test]
        fn compress_does_not_crash(v in (0..2000usize).prop_map(short_elements)) {
            compress(&v, 2*v.len());
        }
    }

    proptest! {
        #[test]
        fn decompress_recovers(v in (0..2000usize).prop_map(short_elements)) {
            let slen = 2 * v.len();
            let n = v.len();
            if let Some(compressed) = compress(&v, slen) {
                let recovered = decompress(&compressed, n).unwrap();
                prop_assert_eq!(v, recovered.clone());
                let recompressed = compress(&recovered, slen).unwrap();
                prop_assert_eq!(compressed, recompressed);
            }
        }
    }

    #[test]
    fn compress_empty_vec_does_not_crash() {
        compress(&[], 0);
    }

    #[test]
    fn test_compress_decompress() {
        let num_iterations = 1000;
        let sigma = 1.5 * f64::from(Q).sqrt();
        let distribution = rand_distr::Normal::<f64>::new(0.0, sigma).unwrap();
        let mut rng = rng();
        let mut num_successes_512 = 0;
        let mut num_successes_1024 = 0;
        for _ in 0..num_iterations {
            const SALT_LEN: usize = 40;
            const HEAD_LEN: usize = 1;
            {
                const N: usize = 512;
                const SIG_BYTELEN: usize = 666;
                // Byte budget left for the compressed coefficients; `compress` takes bytes.
                let slen = SIG_BYTELEN - SALT_LEN - HEAD_LEN;
                let initial: [i16; N] = (0..N)
                    .map(|_| (distribution.sample(&mut rng) + 0.5).floor() as i16)
                    .collect::<Vec<_>>()
                    .try_into()
                    .unwrap();
                if let Some(compressed) = compress(&initial, slen) {
                    if let Some(decompressed) = decompress(&compressed, N) {
                        assert_eq!(initial.to_vec(), decompressed);
                        num_successes_512 += 1;
                    }
                }
            }
            {
                const N: usize = 1024;
                const SIG_BYTELEN: usize = 1280;
                // Byte budget left for the compressed coefficients; `compress` takes bytes.
                let slen = SIG_BYTELEN - SALT_LEN - HEAD_LEN;
                let initial: [i16; 1024] = (0..N)
                    .map(|_| (distribution.sample(&mut rng) + 0.5).floor() as i16)
                    .collect::<Vec<_>>()
                    .try_into()
                    .unwrap();
                if let Some(compressed) = compress(&initial, slen) {
                    if let Some(decompressed) = decompress(&compressed, N) {
                        assert_eq!(initial.to_vec(), decompressed);
                        num_successes_1024 += 1;
                    }
                }
            }
        }
        assert!((num_successes_512 as f64) / (num_iterations as f64) > 0.995);
        assert!((num_successes_1024 as f64) / (num_iterations as f64) > 0.995);
    }

    #[test]
    fn test_compress_equiv() {
        let sigma = 1.5 * f64::from(Q).sqrt();
        let distribution = rand_distr::Normal::<f64>::new(0.0, sigma).unwrap();
        let mut rng = rng();
        let n = 200;
        let initial = (0..n)
            .map(|_| (distribution.sample(&mut rng) + 0.5).floor() as i16)
            .collect::<Vec<_>>();
        // Length in bits for `compress_slow`; `compress` gets the same length in bytes.
        let slen = 2 * n * 8;
        let compressed = compress_slow(&initial, slen).unwrap();
        let compressed_fast = compress(&initial, slen / 8).unwrap();
        assert_eq!(
            compressed,
            compressed_fast,
            "\n{:#?}\n{:#?}",
            BitVec::from_bytes(&compressed),
            BitVec::from_bytes(&compressed_fast)
        );
    }

    #[test]
    fn test_decompress_equiv() {
        let sigma = 1.5 * f64::from(Q).sqrt();
        let distribution = rand_distr::Normal::<f64>::new(0.0, sigma).unwrap();
        let mut rng = rng();
        let num_iterations = 1000;
        for _ in 0..num_iterations {
            let n = rng.random_range(1..100);
            let initial = (0..n)
                .map(|_| (distribution.sample(&mut rng) + 0.5).floor() as i16)
                .collect::<Vec<_>>();
            let slen = 2 * n * 8;
            let compressed = compress(&initial, slen).unwrap();
            let decompressed_fast = decompress(&compressed, n);
            let decompressed_slow = decompress_slow(&compressed, n);
            assert_eq!(decompressed_fast, decompressed_slow);
        }
    }

    #[test]
    fn test_decompress_failures() {
        let sigma = 1.5 * f64::from(Q).sqrt();
        let distribution = rand_distr::Normal::<f64>::new(0.0, sigma).unwrap();
        let mut rng = rng();
        let num_iterations = 1000;
        for _ in 0..num_iterations {
            let n = rng.random_range(1..100);
            let initial = (0..n)
                .map(|_| (distribution.sample(&mut rng) + 0.5).floor() as i16)
                .collect::<Vec<_>>();
            let slen = 2 * n * 8;
            let compressed = compress(&initial, slen).unwrap();
            // Asking for one coefficient more than was encoded must fail.
            assert!(decompress(&compressed, n + 1).is_none());
            // Clearing the final terminator bit must fail.
            let mut compressed_bitvec = BitVec::from_bytes(&compressed);
            let mut index = compressed_bitvec.len();
            while !compressed_bitvec.get(index - 1).unwrap() {
                index -= 1;
            }
            compressed_bitvec.set(index - 1, false);
            let last_bit_flipped = compressed_bitvec.to_bytes();
            assert!(decompress(&last_bit_flipped, n).is_none());
            // Random bytes (with the same zero tail) either fail or re-encode canonically.
            let mut random = compressed.iter().map(|_| rng.random::<u8>()).collect_vec();
            let num_trailing_zeros = compressed
                .iter()
                .cloned()
                .rev()
                .find_position(|&x| x != 0)
                .map(|(pos, _val)| pos)
                .unwrap_or(0);
            let len = random.len();
            for i in 0..num_trailing_zeros {
                random[len - 1 - i] = 0;
            }
            if let Some(decompressed) = decompress(&random, n) {
                let recompressed = compress(&decompressed, slen).unwrap();
                assert_eq!(
                    random,
                    recompressed,
                    "decompressed: {:?}\ndifference: {:?}",
                    decompressed,
                    random
                        .iter()
                        .enumerate()
                        .zip(recompressed.iter().enumerate())
                        .filter(|((_rai, rav), (_rei, rev))| rav != rev)
                        .map(|((rai, rav), (_rei, rev))| format!("{}. {} vs {}", rai, rav, rev))
                        .join(" ")
                );
            }
        }
    }
}