use crate::conversion::{
bit_pack, bit_unpack, hint_bit_pack, hint_bit_unpack, simple_bit_pack, simple_bit_unpack,
};
use crate::helpers::{bit_length, is_in_range};
use crate::types::{R, R0};
use crate::{D, Q};
/// Serializes the public key parts `rho` and `t1` into a `PK_LEN`-byte array (labelled
/// "Alg 22" in the assertion messages).
///
/// Output layout: bytes `0..32` hold `rho`; the remaining bytes are treated as consecutive
/// chunks of `32 * BLQD` bytes, and the `i`-th chunk receives `t1[i]` as packed by
/// `simple_bit_pack` with upper bound `2^BLQD - 1`, where `BLQD = bit_length(Q - 1) - D`.
///
/// # Panics
/// In debug builds, panics if some `t1[i]` fails `is_in_range(_, 0, 2^BLQD - 1)` or if
/// `PK_LEN != 32 + 32 * K * BLQD`.
pub(crate) fn pk_encode<const K: usize, const PK_LEN: usize>(
    rho: &[u8; 32], t1: &[R; K],
) -> [u8; PK_LEN] {
    const BLQD: usize = bit_length(Q - 1) - D as usize;
    let coeff_max = (1 << BLQD) - 1;
    debug_assert!(t1.iter().all(|t| is_in_range(t, 0, coeff_max)), "Alg 22: t1 out of range");
    debug_assert_eq!(PK_LEN, 32 + 32 * K * BLQD, "Alg 22: bad pk/config size");

    let mut pk = [0u8; PK_LEN];
    let (seed, packed) = pk.split_at_mut(32);
    seed.copy_from_slice(rho);

    // `zip` ends with the shorter side, so at most K chunks are ever written.
    for (chunk, poly) in packed.chunks_mut(32 * BLQD).zip(t1.iter()) {
        simple_bit_pack(poly, coeff_max, chunk);
    }
    pk
}
/// Splits a `PK_LEN`-byte public key into `rho` and `t1` (labelled "Alg 23" in the assertion
/// messages); the inverse layout of `pk_encode`.
///
/// Bytes `0..32` are returned as a borrowed `rho`. Each following window of `32 * BLQD`
/// bytes is handed to `simple_bit_unpack` with upper bound `2^BLQD - 1`, where
/// `BLQD = bit_length(Q - 1) - D`, producing `t1[i]`.
///
/// # Errors
/// Forwards the first error reported by `simple_bit_unpack`.
///
/// # Panics
/// In debug builds, panics if `PK_LEN != 32 + 32 * K * BLQD` or if a decoded polynomial
/// fails `is_in_range(_, 0, 2^BLQD - 1)`.
pub(crate) fn pk_decode<const K: usize, const PK_LEN: usize>(
    pk: &[u8; PK_LEN],
) -> Result<(&[u8; 32], [R; K]), &'static str> {
    const BLQD: usize = bit_length(Q - 1) - D as usize;
    debug_assert_eq!(pk.len(), 32 + 32 * K * BLQD, "Alg 23: incorrect pk length");
    debug_assert_eq!(PK_LEN, 32 + 32 * K * BLQD, "Alg 23: bad pk/config size");

    let rho = <&[u8; 32]>::try_from(&pk[..32]).expect("Alg 23: try_from fail");

    // Number of bytes occupied by one packed polynomial.
    let poly_bytes = 32 * BLQD;
    let mut t1 = [R0; K];
    for (i, poly) in t1.iter_mut().enumerate() {
        let begin = 32 + i * poly_bytes;
        *poly = simple_bit_unpack(&pk[begin..begin + poly_bytes], (1 << BLQD) - 1)?;
    }

    debug_assert!(t1.iter().all(|t| is_in_range(t, 0, (1 << BLQD) - 1)), "Alg 23: t1 out of range");
    Ok((rho, t1))
}
/// Serializes the private key parts into an `SK_LEN`-byte array (labelled "Alg 24" in the
/// assertion messages).
///
/// Output layout, in order:
/// * bytes `0..32`: `rho`
/// * bytes `32..64`: `k`
/// * bytes `64..128`: `tr`
/// * `L` blocks of `32 * bit_length(2 * eta)` bytes: `s_1[i]` via `bit_pack(_, eta, eta, _)`
/// * `K` blocks of the same width: `s_2[i]` via `bit_pack(_, eta, eta, _)`
/// * `K` blocks of `32 * D` bytes: `t_0[i]` via `bit_pack(_, 2^(D-1) - 1, 2^(D-1), _)`
///
/// # Panics
/// In debug builds, panics if `eta` is neither 2 nor 4, if any input polynomial fails its
/// `is_in_range` check, or if `SK_LEN` does not match the layout above.
pub(crate) fn sk_encode<const K: usize, const L: usize, const SK_LEN: usize>(
    eta: i32, rho: &[u8; 32], k: &[u8; 32], tr: &[u8; 64], s_1: &[R; L], s_2: &[R; K], t_0: &[R; K],
) -> [u8; SK_LEN] {
    const TOP: i32 = 1 << (D - 1);
    debug_assert!((eta == 2) || (eta == 4), "Alg 24: incorrect eta");
    debug_assert!(s_1.iter().all(|x| is_in_range(x, eta, eta)), "Alg 24: s1 out of range");
    debug_assert!(s_2.iter().all(|x| is_in_range(x, eta, eta)), "Alg 24: s2 out of range");
    debug_assert!(t_0.iter().all(|x| is_in_range(x, TOP - 1, TOP)), "Alg 24: t0 out of range");
    debug_assert_eq!(
        SK_LEN,
        128 + 32 * ((K + L) * bit_length(2 * eta) + D as usize * K),
        "Alg 24: bad sk/config size"
    );

    let mut sk = [0u8; SK_LEN];

    // Fixed-size header: rho | k | tr.
    sk[..32].copy_from_slice(rho);
    sk[32..64].copy_from_slice(k);
    sk[64..128].copy_from_slice(tr);

    // `cursor` always points at the first byte not yet written.
    let mut cursor = 128;

    // s_1 and s_2 share the same packing parameters, so they are written back to back.
    let eta_bytes = 32 * bit_length(2 * eta);
    for poly in s_1.iter().chain(s_2.iter()) {
        bit_pack(poly, eta, eta, &mut sk[cursor..cursor + eta_bytes]);
        cursor += eta_bytes;
    }

    let t0_bytes = 32 * D as usize;
    for poly in t_0.iter() {
        bit_pack(poly, TOP - 1, TOP, &mut sk[cursor..cursor + t0_bytes]);
        cursor += t0_bytes;
    }

    debug_assert_eq!(cursor, sk.len(), "Alg 24: length miscalc");
    sk
}
/// Splits an `SK_LEN`-byte private key into `(rho, k, tr, s_1, s_2, t_0)` (labelled "Alg 25"
/// in the assertion messages); the inverse layout of `sk_encode`.
///
/// `rho`, `k` and `tr` are returned as borrows of bytes `0..32`, `32..64` and `64..128`.
/// The remaining bytes are consumed in order: `L` blocks for `s_1` and `K` blocks for `s_2`
/// (each `32 * bit_length(2 * eta)` bytes, decoded with `bit_unpack(_, eta, eta)`), then `K`
/// blocks of `32 * D` bytes for `t_0`, decoded with `bit_unpack(_, 2^(D-1) - 1, 2^(D-1))`.
///
/// # Errors
/// Forwards the first error reported by `bit_unpack`.
///
/// # Panics
/// In debug builds, panics if `eta` is neither 2 nor 4 or if `SK_LEN` does not match the
/// layout above.
#[allow(clippy::similar_names, clippy::type_complexity)]
pub(crate) fn sk_decode<const K: usize, const L: usize, const SK_LEN: usize>(
    eta: i32, sk: &[u8; SK_LEN],
) -> Result<(&[u8; 32], &[u8; 32], &[u8; 64], [R; L], [R; K], [R; K]), &'static str> {
    const TOP: i32 = 1 << (D - 1);
    debug_assert!((eta == 2) || (eta == 4), "Alg 25: incorrect eta");
    debug_assert_eq!(
        SK_LEN,
        128 + 32 * ((K + L) * bit_length(2 * eta) + D as usize * K),
        "Alg 25: bad sk/config size"
    );

    // Fixed-size header: rho | k | tr.
    let rho = <&[u8; 32]>::try_from(&sk[..32]).expect("Alg 25: try_from1 fail");
    let k = <&[u8; 32]>::try_from(&sk[32..64]).expect("Alg 25: try_from2 fail");
    let tr = <&[u8; 64]>::try_from(&sk[64..128]).expect("Alg 25: try_from3 fail");

    // `cursor` always points at the first byte not yet consumed.
    let mut cursor = 128;

    // s_1 and s_2 share the same packing parameters and are stored back to back.
    let eta_bytes = 32 * bit_length(2 * eta);
    let mut s_1 = [R0; L];
    let mut s_2 = [R0; K];
    for poly in s_1.iter_mut().chain(s_2.iter_mut()) {
        *poly = bit_unpack(&sk[cursor..cursor + eta_bytes], eta, eta)?;
        cursor += eta_bytes;
    }

    let t0_bytes = 32 * D as usize;
    let mut t_0 = [R0; K];
    for poly in t_0.iter_mut() {
        *poly = bit_unpack(&sk[cursor..cursor + t0_bytes], TOP - 1, TOP)?;
        cursor += t0_bytes;
    }

    debug_assert_eq!(cursor, sk.len(), "Alg 25: length miscalc");
    Ok((rho, k, tr, s_1, s_2, t_0))
}
/// Serializes a signature (`c_tilde`, `z`, `h`) into a `SIG_LEN`-byte array (labelled
/// "Alg 26" in the assertion messages).
///
/// Output layout, in order:
/// * bytes `0..LAMBDA_DIV4`: `c_tilde`
/// * `L` blocks of `32 * (1 + bit_length(gamma1 - 1))` bytes: `z[i]` via
///   `bit_pack(_, gamma1 - 1, gamma1, _)`
/// * all remaining bytes: filled by `hint_bit_pack::<CTEST, K>(omega, h, _)`
///
/// # Panics
/// In debug builds, panics if `z` or `h` fail their `is_in_range` checks or if `SIG_LEN`
/// does not equal `LAMBDA_DIV4 + L * 32 * (1 + bit_length(gamma1 - 1)) + |omega| + K`.
pub(crate) fn sig_encode<
    const CTEST: bool,
    const K: usize,
    const L: usize,
    const LAMBDA_DIV4: usize,
    const SIG_LEN: usize,
>(
    gamma1: i32, omega: i32, c_tilde: &[u8; LAMBDA_DIV4], z: &[R; L], h: &[R; K],
) -> [u8; SIG_LEN] {
    debug_assert!(z.iter().all(|x| is_in_range(x, gamma1 - 1, gamma1)), "Alg 26: z out of range");
    debug_assert!(h.iter().all(|x| is_in_range(x, 0, 1)), "Alg 26: h out of range");
    debug_assert_eq!(
        SIG_LEN,
        LAMBDA_DIV4 + L * 32 * (1 + bit_length(gamma1 - 1)) + omega.unsigned_abs() as usize + K,
        "Alg 26: bad sig/config size"
    );

    let mut sigma = [0u8; SIG_LEN];
    sigma[..LAMBDA_DIV4].copy_from_slice(c_tilde);

    // `cursor` always points at the first byte not yet written.
    let mut cursor = LAMBDA_DIV4;
    let z_bytes = 32 * (1 + bit_length(gamma1 - 1));
    for poly in z.iter() {
        bit_pack(poly, gamma1 - 1, gamma1, &mut sigma[cursor..cursor + z_bytes]);
        cursor += z_bytes;
    }

    // Whatever is left after the z polynomials belongs to the hint encoding.
    hint_bit_pack::<CTEST, K>(omega, h, &mut sigma[cursor..]);
    sigma
}
/// Splits a `SIG_LEN`-byte signature into `(c_tilde, z, h)` (labelled "Alg 27" in the
/// assertion messages); the inverse layout of `sig_encode`.
///
/// `c_tilde` is a copy of bytes `0..LAMBDA_DIV4`. The next `L` blocks of
/// `32 * (bit_length(gamma1 - 1) + 1)` bytes are decoded into `z` with
/// `bit_unpack(_, gamma1 - 1, gamma1)`, and every remaining byte is passed to
/// `hint_bit_unpack::<K>(omega, _)`. On success the hint is always returned as `Some`.
///
/// # Errors
/// Forwards the first error reported by `bit_unpack` or `hint_bit_unpack`.
///
/// # Panics
/// In debug builds, panics if `SIG_LEN` does not equal
/// `LAMBDA_DIV4 + L * 32 * (1 + bit_length(gamma1 - 1)) + |omega| + K`.
#[allow(clippy::type_complexity)]
pub(crate) fn sig_decode<
    const K: usize,
    const L: usize,
    const LAMBDA_DIV4: usize,
    const SIG_LEN: usize,
>(
    gamma1: i32, omega: i32, sigma: &[u8; SIG_LEN],
) -> Result<([u8; LAMBDA_DIV4], [R; L], Option<[R; K]>), &'static str> {
    debug_assert_eq!(
        SIG_LEN,
        LAMBDA_DIV4 + L * 32 * (1 + bit_length(gamma1 - 1)) + omega.unsigned_abs() as usize + K,
        "Alg 27: bad sig/config size"
    );

    let mut c_tilde = [0u8; LAMBDA_DIV4];
    c_tilde.copy_from_slice(&sigma[..LAMBDA_DIV4]);

    // `cursor` always points at the first byte not yet consumed.
    let mut cursor = LAMBDA_DIV4;
    let z_bytes = 32 * (bit_length(gamma1 - 1) + 1);
    let mut z = [R0; L];
    for poly in z.iter_mut() {
        *poly = bit_unpack(&sigma[cursor..cursor + z_bytes], gamma1 - 1, gamma1)?;
        cursor += z_bytes;
    }

    // Whatever follows the z polynomials is the hint encoding.
    let h = hint_bit_unpack::<K>(omega, &sigma[cursor..])?;
    Ok((c_tilde, z, Some(h)))
}
/// Packs the `K` polynomials of `w1` into the caller-supplied buffer `w1_tilde` (labelled
/// "Alg 28" in the assertion messages).
///
/// With `bound = (Q - 1) / (2 * gamma2) - 1`, polynomial `w1[i]` is written by
/// `simple_bit_pack(_, bound, _)` into bytes `i * step..(i + 1) * step` of `w1_tilde`, where
/// `step = 32 * bit_length(bound)`.
///
/// # Panics
/// In debug builds, panics if `w1_tilde.len() != 32 * K * bit_length(bound)` or if some
/// `w1[i]` fails `is_in_range(_, 0, bound)`. In any build, slicing panics if `w1_tilde` is
/// too short for the `K` blocks.
pub(crate) fn w1_encode<const K: usize>(gamma2: i32, w1: &[R; K], w1_tilde: &mut [u8]) {
    let bound = (Q - 1) / (2 * gamma2) - 1;
    debug_assert_eq!(
        w1_tilde.len(),
        32 * K * bit_length(bound),
        "Alg 28: bad w1_tilde/config size"
    );
    debug_assert!(w1.iter().all(|r| is_in_range(r, 0, bound)), "Alg 28: w1 out of range");

    // Number of bytes occupied by one packed polynomial.
    let block_bytes = 32 * bit_length(bound);
    for (i, poly) in w1.iter().enumerate() {
        let begin = i * block_bytes;
        simple_bit_pack(poly, bound, &mut w1_tilde[begin..begin + block_bytes]);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use rand_core::RngCore;

    /// Decodes a random `PK_LEN`-byte public key and checks that re-encoding the decoded
    /// parts reproduces the input byte for byte.
    fn check_pk_roundtrip<const K: usize, const PK_LEN: usize>() {
        let mut random_pk = [0u8; PK_LEN];
        rand::thread_rng().fill_bytes(&mut random_pk);
        let (rho, t1) = pk_decode::<K, PK_LEN>(&random_pk).unwrap();
        let res = pk_encode::<K, PK_LEN>(rho, &t1);
        assert_eq!(random_pk, res);
    }

    #[test]
    fn test_pk_encode_decode_roundtrip1() {
        check_pk_roundtrip::<4, 1312>();
    }

    #[test]
    fn test_pk_encode_decode_roundtrip2() {
        check_pk_roundtrip::<6, 1952>();
    }

    #[test]
    fn test_pk_encode_decode_roundtrip3() {
        check_pk_roundtrip::<8, 2592>();
    }

    /// Builds a polynomial whose coefficients are random values in `0..max`.
    fn get_vec(max: u32) -> R {
        // Convert the bound once instead of once per coefficient.
        let modulus = i32::try_from(max).unwrap();
        let mut rnd_r = R0;
        rnd_r.0.iter_mut().for_each(|e| *e = rand::random::<i32>().rem_euclid(modulus));
        rnd_r
    }

    #[test]
    #[allow(clippy::similar_names)]
    fn test_sk_encode_decode_roundtrip1() {
        let (rho, k) = (rand::random::<[u8; 32]>(), rand::random::<[u8; 32]>());
        let mut tr = [0u8; 64];
        rand::thread_rng().fill_bytes(&mut tr);
        let s1 = [get_vec(2), get_vec(2), get_vec(2), get_vec(2)];
        let s2 = [get_vec(2), get_vec(2), get_vec(2), get_vec(2)];
        let t0 = [
            get_vec(1 << 11),
            get_vec(1 << 11),
            get_vec(1 << 11),
            get_vec(1 << 11),
        ];

        let sk = sk_encode::<4, 4, 2560>(2, &rho, &k, &tr, &s1, &s2, &t0);
        let (rho_test, k_test, tr_test, s1_test, s2_test, t0_test) =
            sk_decode::<4, 4, 2560>(2, &sk).expect("sk_decode failed on a freshly encoded key");

        // One assertion per field so a failure names the component that did not survive.
        assert_eq!(rho, *rho_test, "rho mismatch");
        assert_eq!(k, *k_test, "k mismatch");
        assert_eq!(tr, *tr_test, "tr mismatch");
        assert!(s1.iter().zip(s1_test.iter()).all(|(a, b)| a.0 == b.0), "s1 mismatch");
        assert!(s2.iter().zip(s2_test.iter()).all(|(a, b)| a.0 == b.0), "s2 mismatch");
        assert!(t0.iter().zip(t0_test.iter()).all(|(a, b)| a.0 == b.0), "t0 mismatch");
    }

    #[test]
    fn test_sig_roundtrip() {
        let mut c_tilde = [0u8; 2 * 128 / 8];
        rand::thread_rng().fill_bytes(&mut c_tilde);
        let z = [get_vec(2), get_vec(2), get_vec(2), get_vec(2)];
        let h = [get_vec(1), get_vec(1), get_vec(1), get_vec(1)];

        let sigma = sig_encode::<false, 4, 4, { 128 / 4 }, 2420>(1 << 17, 80, &c_tilde, &z, &h);
        let (c_test, z_test, h_test) =
            sig_decode::<4, 4, { 128 / 4 }, 2420>(1 << 17, 80, &sigma).unwrap();
        let h_test = h_test.expect("sig_decode returned no hint");

        // Compare the whole commitment hash, not just a prefix, so corruption anywhere in
        // `c_tilde` is detected.
        assert_eq!(c_tilde, c_test, "c_tilde mismatch");
        assert!(z.iter().zip(z_test.iter()).all(|(a, b)| a.0 == b.0), "z mismatch");
        assert!(h.iter().zip(h_test.iter()).all(|(a, b)| a.0 == b.0), "h mismatch");
    }
}