use crate::primes::{gcd, mod_inverse, random_below};
use crate::bigint::BigUint;
use crate::csprng::Csprng;
use crate::secure::ct_eq_biguint;
/// Validated public parameters for a `(k, n)` Asmuth–Bloom threshold scheme.
///
/// Instances are produced by `AsmuthBloomParams::new`, which checks the
/// invariants documented on each field below before constructing the value.
#[derive(Clone, Debug)]
pub struct AsmuthBloomParams {
/// Secret-space modulus: `split` requires `secret < m0`; `new` requires `m0 > 1`.
m0: BigUint,
/// The `n` share moduli: strictly increasing, each `> m0`, coprime with `m0`,
/// and pairwise coprime (all checked by `new`). Share `i` (1-based) uses
/// `moduli[i - 1]`.
moduli: Vec<BigUint>,
/// Reconstruction threshold, `2 <= k <= n`.
k: usize,
/// Product of the `k` smallest moduli. Every blinded value `y` produced by
/// `split` lies below it, and `reconstruct` rejects CRT results `>= m_bot`.
m_bot: BigUint,
/// `floor(m_bot / m0)`: the blinding factor `a` in `y = secret + a * m0`
/// is drawn from `[0, a_range)`.
a_range: BigUint,
/// Optional CRT table with `pair_inv[i][j] = moduli[j]^{-1} mod moduli[i]`
/// (diagonal entries are zero placeholders). Built only when the largest
/// modulus has at least `mignotte::CRT_PRECOMP_THRESHOLD_BITS` bits.
pair_inv: Option<Vec<Vec<BigUint>>>,
}
/// One participant's share: the residue of the blinded value `y` modulo that
/// participant's modulus.
///
/// `Debug` is implemented by hand so residues are never printed.
// NOTE(review): the derived `PartialEq` compares residues with `BigUint`'s own
// equality, which is presumably not constant-time; `reconstruct` uses
// `ct_eq_biguint` instead. Confirm callers do not compare shares on secret paths.
#[derive(Clone, Eq, PartialEq)]
pub struct Share {
/// 1-based position of this share's modulus in the parameter moduli list.
pub index: usize,
/// `y mod moduli[index - 1]`; `reconstruct` requires `residue < moduli[index - 1]`.
pub residue: BigUint,
}
impl core::fmt::Debug for Share {
    /// Emits a fixed placeholder so that neither the index nor the residue of
    /// a share can leak into logs or panic messages.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        write!(f, "Share(<elided>)")
    }
}
impl AsmuthBloomParams {
    /// Validates and builds a `(k, n)` Asmuth–Bloom parameter set.
    ///
    /// `m0` bounds the secret space (`secret < m0`), `moduli` holds the `n`
    /// per-share moduli, and `k` is the reconstruction threshold.
    ///
    /// Returns `None` unless all of the following hold:
    /// - `2 <= k <= n`;
    /// - `m0 > 1`;
    /// - `moduli` is strictly increasing;
    /// - every modulus is greater than `m0` and coprime with it;
    /// - the moduli are pairwise coprime;
    /// - `m0 * (product of the k-1 largest moduli) < product of the k smallest
    ///   moduli` (the Asmuth–Bloom inequality).
    ///
    /// When the largest modulus has at least
    /// `mignotte::CRT_PRECOMP_THRESHOLD_BITS` bits, a table of pairwise modular
    /// inverses is precomputed for `reconstruct`.
    #[must_use]
    pub fn new(m0: BigUint, moduli: Vec<BigUint>, k: usize) -> Option<Self> {
        let n = moduli.len();
        if k < 2 || k > n || m0 <= BigUint::one() {
            return None;
        }
        // Strict monotonicity also rules out duplicate moduli.
        if moduli.windows(2).any(|w| w[0] >= w[1]) {
            return None;
        }
        let one = BigUint::one();
        if moduli.iter().any(|m| m <= &m0 || gcd(m, &m0) != one) {
            return None;
        }
        let pairwise_coprime = moduli
            .iter()
            .enumerate()
            .all(|(i, a)| moduli[i + 1..].iter().all(|b| gcd(a, b) == one));
        if !pairwise_coprime {
            return None;
        }

        // Any k shares pin y modulo at least m_bot; any k-1 shares pin it modulo
        // at most m_top. The inequality leaves >= m0 candidates for a coalition
        // of k-1.
        let m_bot = product(&moduli[..k]);
        let m_top = product(&moduli[n - (k - 1)..]);
        if m0.mul_ref(&m_top) >= m_bot {
            return None;
        }
        // Blinding factors are drawn from [0, a_range) so that y = s + a*m0 < m_bot.
        let (a_range, _) = m_bot.div_rem(&m0);

        let max_bits = moduli.iter().map(BigUint::bits).max().unwrap_or(0);
        let pair_inv = if max_bits >= crate::mignotte::CRT_PRECOMP_THRESHOLD_BITS {
            Some(Self::pairwise_inverses(&moduli)?)
        } else {
            None
        };

        Some(Self {
            m0,
            moduli,
            k,
            m_bot,
            a_range,
            pair_inv,
        })
    }

    /// Builds `table[i][j] = moduli[j]^{-1} mod moduli[i]` for `i != j`, with a
    /// zero placeholder on the diagonal. Returns `None` if any inverse is missing.
    fn pairwise_inverses(moduli: &[BigUint]) -> Option<Vec<Vec<BigUint>>> {
        moduli
            .iter()
            .enumerate()
            .map(|(i, m_i)| {
                moduli
                    .iter()
                    .enumerate()
                    .map(|(j, m_j)| {
                        if i == j {
                            Some(BigUint::zero())
                        } else {
                            mod_inverse(&m_j.modulo(m_i), m_i)
                        }
                    })
                    .collect::<Option<Vec<BigUint>>>()
            })
            .collect()
    }

    /// Reconstruction threshold: the minimum number of shares `reconstruct` accepts.
    #[must_use]
    pub fn k(&self) -> usize {
        self.k
    }

    /// Total number of shares, one per modulus.
    #[must_use]
    pub fn n(&self) -> usize {
        self.moduli.len()
    }

    /// Secret-space modulus; `split` requires `secret < m0`.
    #[must_use]
    pub fn m0(&self) -> &BigUint {
        &self.m0
    }

    /// The share moduli in strictly increasing order; share `i` (1-based)
    /// uses `moduli()[i - 1]`.
    #[must_use]
    pub fn moduli(&self) -> &[BigUint] {
        &self.moduli
    }
}
/// Splits `secret` into `params.n()` Asmuth–Bloom shares.
///
/// A blinding factor `a` is drawn with `random_below` from `[0, a_range)`,
/// where `a_range = floor(m_bot / m0)`, and the blinded value is
/// `y = secret + a * m0`, which is therefore `< m_bot`. Share `i` (1-based)
/// carries `y mod moduli[i - 1]`.
///
/// # Panics
///
/// Panics if `secret >= params.m0()`.
// NOTE(review): `a` and `y` hold secret-dependent values and are dropped
// without explicit zeroization.
#[must_use]
pub fn split<R: Csprng>(params: &AsmuthBloomParams, rng: &mut R, secret: &BigUint) -> Vec<Share> {
    assert!(secret < &params.m0, "secret must be < m_0");
    // `new` enforces m_bot > m0 * m_top >= m0, so a_range >= 1.
    let a = random_below(rng, &params.a_range).expect("a_range > 0 by construction");
    let y = secret.add_ref(&a.mul_ref(&params.m0));
    params
        .moduli
        .iter()
        .enumerate()
        .map(|(i, m)| Share {
            index: i + 1,
            residue: y.modulo(m),
        })
        .collect()
}
/// Recovers the secret from at least `params.k()` shares.
///
/// The first `k` shares (in slice order) are combined with an incremental CRT
/// fold into the unique `y` below the product of their moduli. Every remaining
/// share then acts as a consistency check: its residue must equal `y mod m_i`,
/// compared with `ct_eq_biguint`.
///
/// Returns `None` when:
/// - fewer than `k` shares are supplied;
/// - a share index is outside `1..=n`, or its residue is not reduced modulo
///   its own modulus;
/// - two shares carry the same index;
/// - the CRT value is `>= m_bot`, which no `y` produced by `split` can be;
/// - any share beyond the first `k` disagrees with the reconstruction.
///
/// If the first `k` shares are exactly those with the `k` smallest moduli,
/// their modulus product equals `m_bot`. In that case the bound check cannot
/// fire, and tampering among those shares is detectable only through extra
/// shares.
#[must_use]
pub fn reconstruct(params: &AsmuthBloomParams, shares: &[Share]) -> Option<BigUint> {
    let k = params.k;
    if shares.len() < k {
        return None;
    }

    // Structural validation and duplicate detection in one pass. Indices are
    // range-checked first, so indexing `seen`/`moduli` cannot panic.
    let n = params.n();
    let mut seen = vec![false; n];
    for s in shares {
        if s.index == 0 || s.index > n || s.residue >= params.moduli[s.index - 1] {
            return None;
        }
        let slot = &mut seen[s.index - 1];
        if *slot {
            return None;
        }
        *slot = true;
    }

    // Garner-style fold: keep y < prod with y ≡ r_j (mod m_j) for each folded
    // share, then lift y to satisfy the next congruence as well.
    let (head, rest) = shares[..k].split_first()?;
    let mut y = head.residue.clone();
    let mut prod = params.moduli[head.index - 1].clone();
    let mut folded: Vec<usize> = Vec::with_capacity(k);
    folded.push(head.index - 1);
    for s in rest {
        let i = s.index - 1;
        let m = &params.moduli[i];
        // inv = prod^{-1} mod m, assembled from the precomputed table if present.
        let inv = match &params.pair_inv {
            Some(table) => folded.iter().fold(BigUint::one(), |acc, &j| {
                debug_assert!(j != i, "fold step would self-multiply pair_inv diagonal");
                BigUint::mod_mul(&acc, &table[i][j], m)
            }),
            None => mod_inverse(&prod.modulo(m), m)?,
        };
        // diff = (r_i - y) mod m, computed without going negative.
        let y_mod_m = y.modulo(m);
        let diff = if s.residue >= y_mod_m {
            s.residue.sub_ref(&y_mod_m)
        } else {
            s.residue.add_ref(m).sub_ref(&y_mod_m)
        };
        let t = BigUint::mod_mul(&diff, &inv, m);
        y = y.add_ref(&prod.mul_ref(&t));
        prod = prod.mul_ref(m);
        folded.push(i);
    }
    let y = y.modulo(&prod);

    // Honest blinded values are always < m_bot.
    if y >= params.m_bot {
        return None;
    }
    for s in &shares[k..] {
        let m = &params.moduli[s.index - 1];
        let pred = y.modulo(m);
        if !ct_eq_biguint(&pred, &s.residue) {
            return None;
        }
    }
    Some(y.modulo(&params.m0))
}
/// Multiplies all `values` together; an empty slice yields one.
fn product(values: &[BigUint]) -> BigUint {
    values
        .iter()
        .fold(BigUint::one(), |running, factor| running.mul_ref(factor))
}
#[must_use]
pub fn small_example_3_of_5() -> AsmuthBloomParams {
let m0 = BigUint::from_u64(5);
let moduli = [11u64, 13, 17, 19, 23]
.into_iter()
.map(BigUint::from_u64)
.collect();
AsmuthBloomParams::new(m0, moduli, 3).expect("hand-validated small parameter set")
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::csprng::ChaCha20Rng;

    /// Deterministic RNG so failures are reproducible.
    fn rng() -> ChaCha20Rng {
        ChaCha20Rng::from_seed(&[91u8; 32])
    }

    #[test]
    fn small_round_trip() {
        let p = small_example_3_of_5();
        let mut r = rng();
        for s_val in 0..5u64 {
            let secret = BigUint::from_u64(s_val);
            let shares = split(&p, &mut r, &secret);
            assert_eq!(shares.len(), 5);
            assert_eq!(reconstruct(&p, &shares[..3]), Some(secret.clone()));
            assert_eq!(reconstruct(&p, &shares[1..4]), Some(secret.clone()));
            assert_eq!(reconstruct(&p, &shares[2..]), Some(secret));
        }
    }

    #[test]
    fn extras_validated() {
        let p = small_example_3_of_5();
        let mut r = rng();
        let secret = BigUint::from_u64(3);
        let shares = split(&p, &mut r, &secret);
        assert_eq!(reconstruct(&p, &shares), Some(secret));
    }

    #[test]
    fn tampered_extra_rejected() {
        let p = small_example_3_of_5();
        let mut r = rng();
        let secret = BigUint::from_u64(2);
        let mut shares = split(&p, &mut r, &secret);
        shares[4].residue = shares[4].residue.add_ref(&BigUint::one()).modulo(&p.moduli[4]);
        assert!(reconstruct(&p, &shares).is_none());
    }

    #[test]
    fn below_threshold_returns_none() {
        let p = small_example_3_of_5();
        let mut r = rng();
        let secret = BigUint::from_u64(1);
        let shares = split(&p, &mut r, &secret);
        assert!(reconstruct(&p, &shares[..2]).is_none());
    }

    #[test]
    fn duplicate_share_rejected() {
        let p = small_example_3_of_5();
        let mut r = rng();
        let secret = BigUint::from_u64(4);
        let mut shares = split(&p, &mut r, &secret);
        shares[1] = shares[0].clone();
        assert!(reconstruct(&p, &shares[..3]).is_none());
    }

    #[test]
    fn malformed_shares_rejected() {
        let p = small_example_3_of_5();
        let mut r = rng();
        let shares = split(&p, &mut r, &BigUint::from_u64(3));

        let mut zero_index = shares[..3].to_vec();
        zero_index[0].index = 0;
        assert!(reconstruct(&p, &zero_index).is_none());

        let mut high_index = shares[..3].to_vec();
        high_index[0].index = p.n() + 1;
        assert!(reconstruct(&p, &high_index).is_none());

        // A residue equal to its modulus is not reduced and must be rejected.
        let mut unreduced = shares[..3].to_vec();
        unreduced[0].residue = p.moduli()[0].clone();
        assert!(reconstruct(&p, &unreduced).is_none());
    }

    #[test]
    #[should_panic(expected = "secret must be < m_0")]
    fn secret_above_m0_panics() {
        let p = small_example_3_of_5();
        let mut r = rng();
        let _ = split(&p, &mut r, &BigUint::from_u64(5));
    }

    #[test]
    fn rejects_non_coprime_with_m0() {
        let m0 = BigUint::from_u64(5);
        let m = vec![
            BigUint::from_u64(11),
            BigUint::from_u64(13),
            BigUint::from_u64(17),
            BigUint::from_u64(19),
            BigUint::from_u64(25),
        ];
        assert!(AsmuthBloomParams::new(m0, m, 3).is_none());
    }

    #[test]
    fn rejects_when_inequality_fails() {
        let m0 = BigUint::from_u64(5);
        let m = vec![
            BigUint::from_u64(7),
            BigUint::from_u64(11),
            BigUint::from_u64(13),
            BigUint::from_u64(17),
            BigUint::from_u64(19),
        ];
        assert!(AsmuthBloomParams::new(m0, m, 3).is_none());
    }

    /// Shares 1..=3 use the three smallest moduli, whose product equals m_bot,
    /// so the `y >= m_bot` check can never fire here. The test only asserts
    /// that tampering does not always silently reproduce the original secret.
    #[test]
    fn first_k_tamper_likely_falls_outside_m_bot() {
        let p = small_example_3_of_5();
        let mut r = rng();
        let secret = BigUint::from_u64(2);
        let shares = split(&p, &mut r, &secret);
        let mut any_disagreement = false;
        for delta in 1..7u64 {
            let mut bad = shares.clone();
            bad[0].residue = bad[0]
                .residue
                .add_ref(&BigUint::from_u64(delta))
                .modulo(&p.moduli[0]);
            if let Some(got) = reconstruct(&p, &bad[..3]) {
                if got != secret {
                    any_disagreement = true;
                }
            } else {
                any_disagreement = true;
            }
        }
        assert!(
            any_disagreement,
            "tampered first-k must not silently roundtrip to the original secret"
        );
    }

    /// Shares 3..=5 span 17 * 19 * 23 = 7429 > m_bot = 2431. Shifting the
    /// mod-17 residue moves y through the 17 values congruent to it mod 437,
    /// and at most 6 of those lie below 2431. So most tampered inputs must be
    /// rejected by the bound check, deterministically.
    #[test]
    fn tamper_outside_smallest_moduli_hits_bound_check() {
        let p = small_example_3_of_5();
        let mut r = rng();
        let secret = BigUint::from_u64(2);
        let shares = split(&p, &mut r, &secret);
        assert_eq!(reconstruct(&p, &shares[2..]), Some(secret));
        let mut rejected = 0usize;
        for delta in 1..17u64 {
            let mut bad = shares[2..].to_vec();
            bad[0].residue = bad[0]
                .residue
                .add_ref(&BigUint::from_u64(delta))
                .modulo(&p.moduli[2]);
            if reconstruct(&p, &bad).is_none() {
                rejected += 1;
            }
        }
        assert!(rejected >= 10, "bound check rejected only {rejected} of 16 tamperings");
    }

    #[test]
    fn larger_parameter_round_trip() {
        let m0 = BigUint::from_u64(11);
        let moduli = [101u64, 103, 107, 109, 113, 127, 131]
            .into_iter()
            .map(BigUint::from_u64)
            .collect();
        let params = AsmuthBloomParams::new(m0, moduli, 4).expect("valid (4,7) params");
        let mut r = rng();
        for s_val in 0..11u64 {
            let secret = BigUint::from_u64(s_val);
            let shares = split(&params, &mut r, &secret);
            assert_eq!(reconstruct(&params, &shares[..4]), Some(secret.clone()));
            assert_eq!(reconstruct(&params, &shares[3..7]), Some(secret));
        }
    }
}