use crate::csprng::Csprng;
/// Number of subpixels each secret pixel expands into for an `n`-of-`n`
/// scheme: `2^(n-1)`.
///
/// # Panics
///
/// Panics if `n < 2`, or if `2^(n-1)` does not fit in a `usize`.
#[must_use]
pub fn pixel_expansion(n: usize) -> usize {
    assert!(n >= 2, "n must be at least 2");
    let shift = n - 1;
    // With overflow checks disabled (release builds) `<<` masks an oversized
    // shift amount instead of failing, which would silently yield a tiny
    // expansion and a broken scheme. Refuse such `n` explicitly.
    assert!(
        shift < usize::BITS as usize,
        "n is too large: 2^(n-1) subpixels do not fit in usize"
    );
    1usize << shift
}
/// Builds the two `n`×`pixel_expansion(n)` basis matrices of the `n`-of-`n`
/// construction, returned as `(white, black)`.
///
/// Every subset of the `n` participants is encoded as a bitmask. The white
/// matrix has one column per even-sized subset and the black matrix one column
/// per odd-sized subset, both in ascending bitmask order. Entry `[i][k]` is
/// `true` iff participant `i` belongs to the `k`-th subset.
fn basis_matrices(n: usize) -> (Vec<Vec<bool>>, Vec<Vec<bool>>) {
    let m = pixel_expansion(n);
    let (even_subsets, odd_subsets): (Vec<u64>, Vec<u64>) =
        (0..(1u64 << n)).partition(|subset| subset.count_ones() % 2 == 0);
    debug_assert_eq!(even_subsets.len(), m);
    debug_assert_eq!(odd_subsets.len(), m);

    // Incidence matrix: one row per participant, one column per subset.
    let incidence = |subsets: &[u64]| -> Vec<Vec<bool>> {
        (0..n)
            .map(|participant| {
                let mut row = vec![false; m];
                for (cell, &subset) in row.iter_mut().zip(subsets) {
                    *cell = (subset >> participant) & 1 == 1;
                }
                row
            })
            .collect()
    };

    (incidence(&even_subsets[..]), incidence(&odd_subsets[..]))
}
/// Returns a random permutation of `0..m` via the Fisher–Yates shuffle.
///
/// Walks the slots from the back. Each slot is swapped with a position drawn by
/// `bounded_index` from the not-yet-fixed prefix (inclusive).
fn random_permutation<R: Csprng>(rng: &mut R, m: usize) -> Vec<usize> {
    let mut order: Vec<usize> = Vec::with_capacity(m);
    order.extend(0..m);
    let mut upper = m;
    while upper > 1 {
        upper -= 1;
        let pick = bounded_index(rng, upper + 1);
        order.swap(upper, pick);
    }
    order
}
/// Draws an index in `0..bound` from `rng` by rejection sampling.
///
/// Reads just enough bytes to cover `bound - 1`, interprets them big-endian,
/// and clears the surplus high bits of the leading byte. It retries until the
/// candidate is below `bound`. Rejecting rather than reducing with `%` avoids
/// modulo bias. When `bound == 1` no randomness is consumed.
///
/// # Panics
///
/// Panics if `bound == 0`.
fn bounded_index<R: Csprng>(rng: &mut R, bound: usize) -> usize {
    assert!(bound > 0);
    if bound == 1 {
        return 0;
    }
    // Significant bits in the largest acceptable value, `bound - 1`.
    let bits = (usize::BITS - (bound - 1).leading_zeros()) as usize;
    let bytes = bits.div_ceil(8);
    let excess = bytes * 8 - bits;
    // `excess` is in 0..=7, so this keeps exactly the low `8 - excess` bits.
    let top_mask: u8 = 0xFF >> excess;
    // `bytes <= size_of::<usize>()`, so a fixed stack buffer suffices. This
    // avoids a heap allocation on every draw; the function runs once per swap
    // of every per-pixel permutation.
    let mut storage = [0u8; std::mem::size_of::<usize>()];
    let buf = &mut storage[..bytes];
    loop {
        rng.fill_bytes(buf);
        buf[0] &= top_mask;
        let candidate = buf
            .iter()
            .fold(0usize, |acc, &b| (acc << 8) | usize::from(b));
        if candidate < bound {
            return candidate;
        }
    }
}
/// Splits a binary secret image into `n` shares using the `n`-of-`n`
/// Naor–Shamir visual cryptography construction.
///
/// `secret[y][x] == true` denotes a black pixel. Each secret pixel becomes
/// `pixel_expansion(n)` subpixels in every share. The black (odd-subset) or
/// white (even-subset) basis matrix is column-permuted with a fresh random
/// permutation drawn from `rng`, and share `i` receives row `i` of the permuted
/// matrix. Stacking all `n` shares and decoding the result recovers `secret`.
///
/// The result is indexed `shares[i][y]`. Each row holds
/// `secret[0].len() * pixel_expansion(n)` subpixels.
///
/// # Panics
///
/// Panics if `n < 2`, if `secret` is empty, if its first row is empty, or if
/// its rows differ in length.
#[must_use]
pub fn split_n_of_n<R: Csprng>(
    rng: &mut R,
    secret: &[Vec<bool>],
    n: usize,
) -> Vec<Vec<Vec<bool>>> {
    assert!(n >= 2, "n must be at least 2");
    assert!(!secret.is_empty(), "secret image must not be empty");
    let width = secret[0].len();
    assert!(width > 0, "secret rows must be non-empty");
    for row in secret.iter() {
        assert_eq!(row.len(), width, "secret rows must be equal length");
    }

    let expansion = pixel_expansion(n);
    let (white_basis, black_basis) = basis_matrices(n);
    let row_capacity = width * expansion;
    let mut shares: Vec<Vec<Vec<bool>>> = (0..n)
        .map(|_| {
            (0..secret.len())
                .map(|_| Vec::with_capacity(row_capacity))
                .collect()
        })
        .collect();

    for (y, secret_row) in secret.iter().enumerate() {
        for &is_black in secret_row {
            let basis = if is_black { &black_basis } else { &white_basis };
            // A new column order for every pixel, drawn in row-major pixel order.
            let order = random_permutation(rng, expansion);
            for (share, basis_row) in shares.iter_mut().zip(basis) {
                share[y].extend(order.iter().map(|&col| basis_row[col]));
            }
        }
    }
    shares
}
/// Stacks shares by OR-ing them subpixel by subpixel. A subpixel of the result
/// is dark (`true`) when it is dark in at least one share.
///
/// Returns `None` if `shares` is empty. It also returns `None` if any share has
/// a different number of rows than the first share, or any row has a different
/// width than the first share's first row.
#[must_use]
pub fn stack(shares: &[Vec<Vec<bool>>]) -> Option<Vec<Vec<bool>>> {
    let first = shares.first()?;
    let height = first.len();
    let width = first.first().map_or(0, Vec::len);

    let consistent = shares
        .iter()
        .all(|share| share.len() == height && share.iter().all(|row| row.len() == width));
    if !consistent {
        return None;
    }

    let mut combined = vec![vec![false; width]; height];
    for share in shares {
        for (out_row, in_row) in combined.iter_mut().zip(share) {
            for (out_px, &in_px) in out_row.iter_mut().zip(in_row) {
                *out_px |= in_px;
            }
        }
    }
    Some(combined)
}
/// Decodes an image obtained by stacking all `n` shares back into the secret.
///
/// Each row is cut into blocks of `pixel_expansion(n)` subpixels. A fully dark
/// block decodes to `true` (black). A block with exactly one light subpixel
/// decodes to `false` (white).
///
/// Returns `Some(vec![])` for an empty input. Returns `None` if the row width
/// is not a multiple of the expansion, if rows differ in length, or if any
/// block has another weight.
///
/// # Panics
///
/// Panics if `n < 2`.
#[must_use]
pub fn decode(stacked: &[Vec<bool>], n: usize) -> Option<Vec<Vec<bool>>> {
    assert!(n >= 2, "n must be at least 2");
    let block_len = pixel_expansion(n);
    let Some(first_row) = stacked.first() else {
        return Some(Vec::new());
    };
    let row_len = first_row.len();
    if !row_len.is_multiple_of(block_len) {
        return None;
    }

    stacked
        .iter()
        .map(|row| {
            if row.len() != row_len {
                return None;
            }
            row.chunks_exact(block_len)
                .map(|block| {
                    let dark = block.iter().filter(|&&b| b).count();
                    // Number of light subpixels in the block.
                    match block_len - dark {
                        0 => Some(true),
                        1 => Some(false),
                        _ => None,
                    }
                })
                .collect::<Option<Vec<bool>>>()
        })
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::csprng::ChaCha20Rng;

    /// Fixed-seed RNG so every run produces the same shares.
    fn rng() -> ChaCha20Rng {
        ChaCha20Rng::from_seed(&[0x76u8; 32])
    }

    /// `h`×`w` checkerboard secret whose top-left pixel is `true` (black).
    fn checker(h: usize, w: usize) -> Vec<Vec<bool>> {
        (0..h)
            .map(|y| (0..w).map(|x| (x + y) % 2 == 0).collect())
            .collect()
    }

    #[test]
    fn pixel_expansion_table() {
        assert_eq!(pixel_expansion(2), 2);
        assert_eq!(pixel_expansion(3), 4);
        assert_eq!(pixel_expansion(4), 8);
    }

    #[test]
    fn basis_matrices_have_expected_shape_n2() {
        let (c0, c1) = basis_matrices(2);
        assert_eq!(c0[0], vec![false, true]);
        assert_eq!(c0[1], vec![false, true]);
        assert_eq!(c1[0], vec![true, false]);
        assert_eq!(c1[1], vec![false, true]);
    }

    #[test]
    fn round_trip_2_of_2() {
        let mut r = rng();
        let secret = checker(4, 4);
        let shares = split_n_of_n(&mut r, &secret, 2);
        assert_eq!(shares.len(), 2);
        for s in &shares {
            assert_eq!(s.len(), 4);
            for row in s {
                assert_eq!(row.len(), 8);
            }
        }
        let stacked = stack(&shares).unwrap();
        let decoded = decode(&stacked, 2).unwrap();
        assert_eq!(decoded, secret);
    }

    #[test]
    fn round_trip_3_of_3() {
        let mut r = rng();
        let secret = checker(3, 5);
        let shares = split_n_of_n(&mut r, &secret, 3);
        assert_eq!(shares.len(), 3);
        let m = pixel_expansion(3);
        for s in &shares {
            assert_eq!(s.len(), 3);
            for row in s {
                assert_eq!(row.len(), 5 * m);
            }
        }
        let stacked = stack(&shares).unwrap();
        let decoded = decode(&stacked, 3).unwrap();
        assert_eq!(decoded, secret);
    }

    #[test]
    fn round_trip_4_of_4() {
        let mut r = rng();
        let secret = checker(2, 6);
        let shares = split_n_of_n(&mut r, &secret, 4);
        let stacked = stack(&shares).unwrap();
        let decoded = decode(&stacked, 4).unwrap();
        assert_eq!(decoded, secret);
    }

    #[test]
    fn fewer_than_n_shares_are_indistinguishable_from_random_at_white_pixel() {
        let mut r = rng();
        let n = 3;
        let secret = vec![vec![false, true]];
        let shares = split_n_of_n(&mut r, &secret, n);
        let partial = stack(&shares[..2]).unwrap();
        let m = pixel_expansion(n);
        let weight_white: usize = partial[0][0..m].iter().filter(|&&b| b).count();
        let weight_black: usize = partial[0][m..2 * m].iter().filter(|&&b| b).count();
        assert_eq!(weight_white, weight_black);
    }

    #[test]
    fn share_alone_carries_no_distinguishing_information() {
        let mut r = rng();
        let n = 3;
        let m = pixel_expansion(n);
        let secret = vec![vec![false, true]];
        let shares = split_n_of_n(&mut r, &secret, n);
        let weight_white: usize = shares[0][0][0..m].iter().filter(|&&b| b).count();
        let weight_black: usize = shares[0][0][m..2 * m].iter().filter(|&&b| b).count();
        assert_eq!(weight_white, weight_black);
    }

    #[test]
    fn decode_rejects_malformed_block() {
        let n = 2;
        let m = pixel_expansion(n);
        // An all-light block has weight 0. That is neither `m` (black) nor
        // `m - 1` (white), so decoding must refuse it.
        let bad = vec![vec![false; m]];
        assert!(decode(&bad, n).is_none());
    }

    #[test]
    fn decode_maps_block_weights_to_pixels() {
        // m = 2: weight 2 decodes to black (true), weight 1 to white (false).
        let stacked = vec![vec![true, true, false, true]];
        assert_eq!(decode(&stacked, 2), Some(vec![vec![true, false]]));
    }

    #[test]
    fn decode_rejects_width_not_multiple_of_expansion() {
        assert!(decode(&[vec![true; 3]], 2).is_none());
    }

    #[test]
    fn decode_rejects_ragged_rows() {
        assert!(decode(&[vec![true, true], vec![true]], 2).is_none());
    }

    #[test]
    fn stack_rejects_mismatched_dimensions() {
        let a: Vec<Vec<bool>> = vec![vec![false, true]];
        let b: Vec<Vec<bool>> = vec![vec![false, true, false]];
        assert!(stack(&[a, b]).is_none());
    }

    #[test]
    fn stack_rejects_empty_share_list() {
        assert!(stack(&[]).is_none());
    }

    #[test]
    fn bounded_index_stays_below_bound() {
        let mut r = rng();
        for bound in 1..=300 {
            assert!(bounded_index(&mut r, bound) < bound);
        }
    }

    #[test]
    fn random_permutation_is_a_permutation() {
        let mut r = rng();
        let mut perm = random_permutation(&mut r, 16);
        perm.sort_unstable();
        assert_eq!(perm, (0..16).collect::<Vec<_>>());
    }
}