use crate::helpers::{bit_length, ensure, is_in_range};
use crate::types::{R, R0};
use crate::Q;
/// Algorithm 14 (`CoeffFromThreeBytes`): assembles a candidate coefficient from three bytes.
///
/// The bytes are read little-endian: `b[0]` is the low byte and `b[2]` the high byte. Bit 7 of
/// `b[2]` is discarded, giving a 23-bit candidate that is accepted only if it is below `Q`.
/// When `CTEST` is set, bit 6 of `b[2]` is also discarded. This is presumably to keep the
/// candidate small for constant-time testing; confirm against the callers.
///
/// # Errors
/// Returns `Err("Alg 14: returns ⊥")` when the candidate is not below `Q`.
pub(crate) fn coeff_from_three_bytes<const CTEST: bool>(b: [u8; 3]) -> Result<i32, &'static str> {
    let [lo, mid, hi] = b;
    // (x & 0x7F) & 0x3F == x & 0x3F, so a single mask covers both modes.
    let top_mask: u8 = if CTEST { 0x3F } else { 0x7F };
    let candidate = (i32::from(hi & top_mask) << 16) | (i32::from(mid) << 8) | i32::from(lo);
    if candidate >= Q {
        return Err("Alg 14: returns ⊥");
    }
    Ok(candidate)
}
/// Algorithm 15 (`CoeffFromHalfByte`): maps a nibble to a small coefficient for the given `eta`.
///
/// The mapping depends on `eta`:
/// - `eta == 2`: nibbles `0..15` map to `2 - (b mod 5)`.
/// - `eta == 4`: nibbles `0..9` map to `4 - b`.
///
/// When `CTEST` is set, only the low three bits of `b` are used.
///
/// # Errors
/// Returns `Err("Alg 15: returns ⊥")` when the nibble is outside the accepted range for `eta`.
///
/// # Panics
/// In debug builds, panics if `eta` is neither 2 nor 4, or if `b >= 16`.
pub(crate) fn coeff_from_half_byte<const CTEST: bool>(
    eta: i32, b: u8,
) -> Result<i32, &'static str> {
    // Fixed-point (2^-24) reciprocal of 5. `(n * M5) >> 24` equals `n / 5` for every n below 15.
    // This replaces a division, presumably to avoid data-dependent division timing.
    const M5: i32 = ((1i32 << 24) / 5) + 1;
    debug_assert!((eta == 2) || (eta == 4), "Alg 15: incorrect eta");
    debug_assert!(b < 16, "Alg 15: b out of range");
    let nibble = i32::from(if CTEST { b & 0x07 } else { b });
    match eta {
        2 if nibble < 15 => {
            let quotient = (nibble * M5) >> 24;
            Ok(2 - (nibble - 5 * quotient))
        }
        4 if nibble < 9 => Ok(4 - nibble),
        _ => Err("Alg 15: returns ⊥"),
    }
}
/// Algorithm 16 (`SimpleBitPack`): packs the coefficients of `w` (each in `0..=b`) into
/// `bytes_out`, using `bit_length(b)` bits per coefficient.
///
/// This is [`bit_pack`] with the lower bound `a` fixed to zero.
///
/// # Panics
/// In debug builds, panics if `b` is outside `1..2^20`, if `w` fails the range check, or if
/// `bytes_out.len() != 32 * bit_length(b)`.
pub(crate) fn simple_bit_pack(w: &R, b: i32, bytes_out: &mut [u8]) {
    debug_assert!((1..1024 * 1024).contains(&b), "Alg 16: b out of range");
    debug_assert!(is_in_range(w, 0, b), "Alg 16: w out of range");
    debug_assert_eq!(
        bytes_out.len(),
        32 * bit_length(b),
        "Alg 16: incorrect size of output bytes"
    );
    let lower_bound = 0;
    bit_pack(w, lower_bound, b, bytes_out);
}
/// Algorithm 17 (`BitPack`): packs the coefficients of `w` into `bytes_out`, using
/// `bit_length(a + b)` bits each.
///
/// Encoding of each coefficient:
/// - When `a > 0`, the coefficient is stored as `b - w_i`.
/// - Otherwise it is stored as `|w_i|`, which is `w_i` for non-negative input.
///
/// Bits are emitted least-significant first, so the first coefficient occupies the low bits of
/// `bytes_out[0]`.
///
/// # Panics
/// - In debug builds, panics if `a`/`b` are out of range, if `w` fails the range check, or if the
///   output length does not match.
/// - Panics if `bytes_out` is too short.
pub(crate) fn bit_pack(w: &R, a: i32, b: i32, bytes_out: &mut [u8]) {
    debug_assert!((0..(1024 * 1024)).contains(&a), "Alg 17: a out of range");
    debug_assert!((1..(1024 * 1024)).contains(&b), "Alg 17: b out of range");
    debug_assert!(is_in_range(w, a, b), "Alg 17: w out of range");
    debug_assert_eq!(w.0.len() * bit_length(a + b), bytes_out.len() * 8, "Alg 17: bad output size");
    let width = bit_length(a + b);
    // Bit accumulator: `acc` holds `filled` pending bits that have not yet been written out.
    let mut acc = 0u32;
    let mut filled = 0;
    let mut out_pos = 0;
    for &coeff in w.0.iter() {
        let encoded = if a > 0 { b.abs_diff(coeff) } else { coeff.unsigned_abs() };
        acc |= encoded << filled;
        filled += width;
        // Flush every complete byte now sitting in the accumulator.
        while filled >= 8 {
            bytes_out[out_pos] = acc.to_le_bytes()[0];
            out_pos += 1;
            acc >>= 8;
            filled -= 8;
        }
    }
}
/// Algorithm 18 (`SimpleBitUnpack`): inverse of [`simple_bit_pack`].
///
/// Decodes `v` into coefficients, each expected to lie in `0..=b`. This delegates to
/// [`bit_unpack`] with `a = 0`.
///
/// # Errors
/// Returns `Err("Alg 18: w out of range")` if any decoded coefficient fails the range check.
///
/// # Panics
/// In debug builds, panics if `b` is outside `1..2^20` or if `v.len() != 32 * bit_length(b)`.
pub(crate) fn simple_bit_unpack(v: &[u8], b: i32) -> Result<R, &'static str> {
    debug_assert!((1..(1024 * 1024)).contains(&b), "Alg 18: b out of range");
    debug_assert_eq!(v.len(), 32 * bit_length(b), "Alg 18: bad output size");
    bit_unpack(v, 0, b).map_err(|_| "Alg 18: w out of range")
}
/// Algorithm 19 (`BitUnpack`): inverse of [`bit_pack`].
///
/// Reads `bit_length(a + b)`-bit fields from `v`, least-significant bit first, into the 256
/// coefficients of an `R`. Each field is decoded as follows:
/// - When `a == 0`, the field is the coefficient itself.
/// - Otherwise the coefficient is `b - field`.
///
/// # Errors
/// Returns `Err("Alg 19: w out of range")` when the decoded polynomial fails
/// `is_in_range(.., bot, b)`. Here `bot = |b - 2^width + 1|`.
///
/// NOTE(review): going by its use in `bit_pack`, `is_in_range(w, lo, hi)` appears to check
/// `-lo <= w_i <= hi`. If so, in practice this rejects only `a == 0` fields greater than `b`.
///
/// # Panics
/// - In debug builds, panics on out-of-range `a`/`b` or a wrong input length.
/// - Panics if `v` holds more than 256 fields.
pub(crate) fn bit_unpack(v: &[u8], a: i32, b: i32) -> Result<R, &'static str> {
    debug_assert!((0..(1024 * 1024)).contains(&a), "Alg 19: a out of range");
    debug_assert!((1..(1024 * 1024)).contains(&b), "Alg 19: b out of range");
    debug_assert_eq!(v.len(), 32 * bit_length(a + b), "Alg 19: bad output size");
    let width: i32 = bit_length(a + b).try_into().expect("Alg 19: try_into fail");
    let mut w_out = R([0i32; 256]);
    // Bit reservoir: `acc` holds `filled` not-yet-consumed bits, least significant first.
    let mut acc = 0i32;
    let mut filled: i32 = 0;
    let mut slot = 0usize;
    for &byte in v {
        acc |= i32::from(byte) << filled;
        filled += 8;
        // Peel off every complete field currently available.
        while filled >= width {
            let field = acc & ((1 << width) - 1);
            w_out.0[slot] = if a == 0 { field } else { b - field };
            slot += 1;
            acc >>= width;
            filled -= width;
        }
    }
    let bot = i32::abs(b - (1 << width) + 1);
    ensure!(is_in_range(&w_out, bot, b), "Alg 19: w out of range");
    Ok(w_out)
}
/// Algorithm 20 (`HintBitPack`): encodes the 0/1 hint polynomials `h` into `y_bytes`.
///
/// Output layout (`y_bytes.len() == omega + K`):
/// - `y_bytes[..omega]` lists the positions `j` of every nonzero coefficient, polynomial by
///   polynomial, in increasing `j` order.
/// - `y_bytes[omega + i]` holds the running count of positions written after polynomial `i`.
/// - Every byte not written is zero.
///
/// When `CTEST` is set, every position is written regardless of `h`, and writing stops once the
/// index reaches `y_bytes.len()`. This is presumably to exercise a data-independent path in
/// constant-time tests; confirm against the callers.
///
/// # Panics
/// In debug builds, panics if:
/// - `omega + K` is outside `1..256`,
/// - the output length is wrong,
/// - `h` has a coefficient other than 0/1,
/// - or (non-`CTEST` only) `h` has more than `omega` ones in total.
pub(crate) fn hint_bit_pack<const CTEST: bool, const K: usize>(
    omega: i32, h: &[R; K], y_bytes: &mut [u8],
) {
    let omega_u = usize::try_from(omega).expect("Alg 20: try_from fail");
    debug_assert!((1..256).contains(&(omega_u + K)), "Alg 20: omega+K out of range");
    debug_assert_eq!(y_bytes.len(), omega_u + K, "Alg 20: bad output size");
    debug_assert!(h.iter().all(|r| is_in_range(r, 0, 1)), "Alg 20: h not 0/1");
    // The `omega` position bytes are shared by all K polynomials, so the limit applies to the
    // TOTAL number of ones. A per-polynomial check would let the positions spill into the count
    // bytes (or past the buffer). CTEST mode ignores `h` and bounds the index itself.
    debug_assert!(
        CTEST
            || h.iter().map(|r| r.0.iter().filter(|&&e| e == 1).count()).sum::<usize>()
                <= omega_u,
        "Alg 20: too many 1's in h"
    );
    y_bytes.fill(0);
    let mut index = 0;
    for (i, h_i) in h.iter().enumerate() {
        for (j, &coeff) in h_i.0.iter().enumerate() {
            // `>=` instead of `> len - 1`: same bound, but cannot underflow on an empty buffer.
            if CTEST && (index >= y_bytes.len()) {
                continue;
            }
            if CTEST || (coeff != 0) {
                // `j` is below 256, so its low byte is the full position.
                y_bytes[index] = j.to_le_bytes()[0];
                index += 1;
            }
        }
        y_bytes[omega_u + i] = index.to_le_bytes()[0];
    }
}
/// Algorithm 21 (`HintBitUnpack`): decodes `y_bytes` back into K 0/1 hint polynomials.
///
/// Expected layout:
/// - `y_bytes[omega + i]` is the cumulative count of positions up to and including polynomial
///   `i`.
/// - `y_bytes[..omega]` holds the positions themselves.
///
/// Malformed encodings are rejected, so every valid `h` has a single encoding.
///
/// # Errors
/// - `"Alg 21a: returns ⊥ (4)"` when a cumulative count decreases or exceeds `omega`.
/// - `"Alg 21a: returns ⊥ (9)"` when the positions within one polynomial are not strictly
///   increasing.
/// - `"Alg 21b: returns ⊥ (17)"` when the unused position bytes are not all zero.
///
/// # Panics
/// In debug builds, panics if `omega + K` is outside `1..256` or the input length is wrong.
pub(crate) fn hint_bit_unpack<const K: usize>(
    omega: i32, y_bytes: &[u8],
) -> Result<[R; K], &'static str> {
    let omega_u = usize::try_from(omega).expect("Alg 21: omega try_into fail");
    debug_assert!((1..256).contains(&(omega_u + K)), "Alg 21: omega+K too large");
    debug_assert_eq!(y_bytes.len(), omega_u + K, "Alg 21: bad output size");
    // omega < 256 (asserted above), so its low byte is omega itself.
    let omega_b = omega.to_le_bytes()[0];
    let mut h: [R; K] = [R0; K];
    let mut index = 0u8;
    for (i, h_i) in h.iter_mut().enumerate() {
        let end = y_bytes[omega_u + i];
        if (end < index) || (end > omega_b) {
            return Err("Alg 21a: returns ⊥ (4)");
        }
        let first = index;
        while index < end {
            let pos = usize::from(index);
            // Strictly increasing positions within a polynomial guarantee a unique encoding.
            if (index > first) && (y_bytes[pos - 1] >= y_bytes[pos]) {
                return Err("Alg 21a: returns ⊥ (9)");
            }
            h_i.0[usize::from(y_bytes[pos])] = 1;
            index += 1;
        }
    }
    // Here `index <= omega_b`, because every accepted `end` was `<= omega_b`, so the range is valid.
    if y_bytes[usize::from(index)..usize::from(omega_b)].iter().any(|&e| e != 0) {
        return Err("Alg 21b: returns ⊥ (17)");
    }
    debug_assert!(
        h.iter().map(|r| r.0.iter().filter(|&&e| e == 1).count()).sum::<usize>() <= omega_u,
        "Alg 21: too many 1's in h"
    );
    Ok(h)
}
#[cfg(test)]
mod tests {
    use super::*;
    use rand_core::RngCore;

    #[test]
    fn test_coef_from_three_bytes1() {
        let bytes = [0x12u8, 0x34, 0x56];
        let res = coeff_from_three_bytes::<false>(bytes).unwrap();
        assert_eq!(res, 0x0056_3412);
    }

    #[test]
    fn test_coef_from_three_bytes2() {
        // Bit 7 of the high byte is discarded.
        let bytes = [0x12u8, 0x34, 0x80];
        let res = coeff_from_three_bytes::<false>(bytes).unwrap();
        assert_eq!(res, 0x0000_3412);
    }

    #[test]
    fn test_coef_from_three_bytes3() {
        let bytes = [0x01u8, 0xe0, 0x80];
        let res = coeff_from_three_bytes::<false>(bytes).unwrap();
        assert_eq!(res, 0x0000_e001);
    }

    #[test]
    #[should_panic(expected = "panic: out of range")]
    fn test_coef_from_three_bytes4() {
        // The assembled value 0x7f_e001 is expected to be rejected (not below Q), so `expect`
        // panics with the given message.
        let bytes = [0x01u8, 0xe0, 0x7f];
        let _unused = coeff_from_three_bytes::<false>(bytes).expect("panic: out of range");
    }

    #[test]
    fn test_coef_from_half_byte1() {
        let inp = 3;
        let res = coeff_from_half_byte::<false>(2, inp).unwrap();
        assert_eq!(-1, res);
    }

    #[test]
    fn test_coef_from_half_byte2() {
        let inp = 8;
        let res = coeff_from_half_byte::<false>(4, inp).unwrap();
        assert_eq!(-4, res);
    }

    // The expected panic comes from a `debug_assert!` (b >= 16), which is compiled out in
    // release builds; there the function just returns `Err`.
    #[should_panic]
    #[allow(clippy::should_panic_without_expect)] // any debug-assertion panic is acceptable here
    #[cfg_attr(not(debug_assertions), ignore = "relies on debug_assert!")]
    #[test]
    fn test_coef_from_half_byte_validation1() {
        let inp = 22;
        let res = coeff_from_half_byte::<false>(2, inp);
        assert!(res.is_err());
    }

    // The expected panic comes from a `debug_assert!` (eta not 2/4).
    #[should_panic]
    #[allow(clippy::should_panic_without_expect)] // any debug-assertion panic is acceptable here
    #[cfg_attr(not(debug_assertions), ignore = "relies on debug_assert!")]
    #[test]
    fn test_coef_from_half_byte_validation2() {
        let inp = 5;
        let res = coeff_from_half_byte::<false>(1, inp);
        assert!(res.is_err());
    }

    #[test]
    fn test_coef_from_half_byte_validation3() {
        let inp = 10;
        let res = coeff_from_half_byte::<false>(4, inp);
        assert!(res.is_err());
    }

    #[test]
    fn test_simple_bit_pack_roundtrip() {
        // With b = 2^6 - 1, every 6-bit field is a valid coefficient, so unpack then pack must
        // reproduce the random input exactly.
        let mut random_bytes = [0u8; 32 * 6];
        rand::thread_rng().fill_bytes(&mut random_bytes);
        let r = simple_bit_unpack(&random_bytes, (1 << 6) - 1).unwrap();
        let mut res = [0u8; 32 * 6];
        simple_bit_pack(&r, (1 << 6) - 1, &mut res);
        assert_eq!(random_bytes, res);
    }

    // Wrong input length. In debug builds the length `debug_assert!` fires; in release builds the
    // extra fields index past the 256 coefficients and panic as well.
    #[test]
    #[should_panic]
    #[allow(clippy::should_panic_without_expect)] // panic source differs between debug and release
    fn test_simple_bit_unpack_validation1() {
        let mut random_bytes = [0u8; 32 * 7];
        rand::thread_rng().fill_bytes(&mut random_bytes);
        let res = simple_bit_unpack(&random_bytes, (1 << 6) - 1);
        assert!(res.is_err());
    }

    // Same as above, but calling `bit_unpack` directly.
    #[test]
    #[should_panic]
    #[allow(clippy::should_panic_without_expect)] // panic source differs between debug and release
    fn test_bit_unpack_validation1() {
        let mut random_bytes = [0u8; 32 * 7];
        rand::thread_rng().fill_bytes(&mut random_bytes);
        let res = bit_unpack(&random_bytes, 0, (1 << 6) - 1);
        assert!(res.is_err());
    }

    #[test]
    fn test_simple_bit_pack_validation1() {
        // Packing the zero polynomial must overwrite every (random) output byte with zero.
        let mut random_bytes = [0u8; 32 * 6];
        rand::thread_rng().fill_bytes(&mut random_bytes);
        let r = R([0i32; 256]);
        simple_bit_pack(&r, (1 << 6) - 1, &mut random_bytes);
        assert!(random_bytes.iter().all(|&byte| byte == 0));
    }

    #[test]
    #[should_panic(expected = "Alg 16: b out of range")]
    #[cfg_attr(not(debug_assertions), ignore = "relies on debug_assert!")]
    fn test_simple_bit_pack_b_range() {
        let w = R0;
        let mut bytes = [0u8; 32];
        simple_bit_pack(&w, 0, &mut bytes);
    }

    #[test]
    #[should_panic(expected = "Alg 16: w out of range")]
    #[cfg_attr(not(debug_assertions), ignore = "relies on debug_assert!")]
    fn test_simple_bit_pack_w_range() {
        let mut w = R0;
        w.0[0] = 5;
        let mut bytes = [0u8; 32];
        simple_bit_pack(&w, 3, &mut bytes);
    }

    #[test]
    #[should_panic(expected = "Alg 16: incorrect size of output bytes")]
    #[cfg_attr(not(debug_assertions), ignore = "relies on debug_assert!")]
    fn test_simple_bit_pack_output_size() {
        let w = R0;
        let mut bytes = [0u8; 65];
        simple_bit_pack(&w, 2, &mut bytes);
    }

    #[test]
    #[should_panic(expected = "Alg 17: a out of range")]
    #[cfg_attr(not(debug_assertions), ignore = "relies on debug_assert!")]
    fn test_bit_pack_a_range() {
        let w = R0;
        let mut bytes = [0u8; 32];
        bit_pack(&w, -1, 2, &mut bytes);
    }

    #[test]
    #[should_panic(expected = "Alg 17: b out of range")]
    #[cfg_attr(not(debug_assertions), ignore = "relies on debug_assert!")]
    fn test_bit_pack_b_range() {
        let w = R0;
        let mut bytes = [0u8; 32];
        bit_pack(&w, 0, 0, &mut bytes);
    }

    #[test]
    #[should_panic(expected = "Alg 17: w out of range")]
    #[cfg_attr(not(debug_assertions), ignore = "relies on debug_assert!")]
    fn test_bit_pack_w_range() {
        let mut w = R0;
        w.0[0] = 10;
        let mut bytes = [0u8; 32];
        bit_pack(&w, 2, 5, &mut bytes);
    }

    #[test]
    #[should_panic(expected = "Alg 18: b out of range")]
    #[cfg_attr(not(debug_assertions), ignore = "relies on debug_assert!")]
    fn test_simple_bit_unpack_b_range() {
        let bytes = [0u8; 32];
        let _unused = simple_bit_unpack(&bytes, 0);
    }

    #[test]
    #[should_panic(expected = "Alg 18: bad output size")]
    #[cfg_attr(not(debug_assertions), ignore = "relies on debug_assert!")]
    fn test_simple_bit_unpack_input_size() {
        let bytes = [0u8; 65];
        let _unused = simple_bit_unpack(&bytes, 2);
    }

    #[test]
    #[should_panic(expected = "Alg 20: omega+K out of range")]
    #[cfg_attr(not(debug_assertions), ignore = "relies on debug_assert!")]
    fn test_hint_bit_pack_omega_k_range() {
        const K: usize = 255;
        let h = [R0; K];
        let mut y_bytes = [0u8; 256];
        hint_bit_pack::<false, K>(2, &h, &mut y_bytes);
    }

    #[test]
    #[should_panic(expected = "Alg 20: h not 0/1")]
    #[cfg_attr(not(debug_assertions), ignore = "relies on debug_assert!")]
    fn test_hint_bit_pack_h_range() {
        const K: usize = 2;
        let mut h = [R0; K];
        h[0].0[0] = 2;
        let mut y_bytes = [0u8; 4];
        hint_bit_pack::<false, K>(2, &h, &mut y_bytes);
    }

    #[test]
    #[should_panic(expected = "Alg 21: omega+K too large")]
    #[cfg_attr(not(debug_assertions), ignore = "relies on debug_assert!")]
    fn test_hint_bit_unpack_omega_k_range() {
        const K: usize = 255;
        let y_bytes = [0u8; 256];
        let _unused = hint_bit_unpack::<K>(2, &y_bytes);
    }
}