use crate::field::PrimeField;
use crate::bigint::BigUint;
use crate::csprng::Csprng;
use crate::secure::ct_eq_biguint;
/// Monotone access structure over 1-based player identifiers.
///
/// A coalition can recover a secret shared under a formula exactly when the formula
/// evaluates to true with each present player's `Party` leaf read as `true`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Formula {
    /// Leaf owned by the given player; `split` panics if the id is 0.
    Party(usize),
    /// Satisfied only when every child is satisfied; `split` divides the value
    /// additively among the children.
    And(Vec<Formula>),
    /// Satisfied when any child is satisfied; `split` copies the value to every child.
    Or(Vec<Formula>),
}

impl Formula {
    /// Builds a leaf for player `j`.
    #[must_use]
    pub fn party(j: usize) -> Self {
        Self::Party(j)
    }

    /// Builds a conjunction over `children`.
    #[must_use]
    pub fn and(children: Vec<Formula>) -> Self {
        Self::And(children)
    }

    /// Builds a disjunction over `children`.
    #[must_use]
    pub fn or(children: Vec<Formula>) -> Self {
        Self::Or(children)
    }
}
/// A single piece of a shared secret, addressed by its position in the formula tree.
#[derive(Clone, PartialEq, Eq)]
pub struct ShareFragment {
    /// Child indices leading from the formula root to the leaf this fragment belongs to.
    pub path: Vec<u32>,
    /// The value held at that leaf. This is secret material.
    pub value: BigUint,
}

impl core::fmt::Debug for ShareFragment {
    // Intentionally prints no fields so secret material never ends up in logs or panics.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        write!(f, "ShareFragment(<elided>)")
    }
}
/// Everything one player receives from `split`: the fragments for each leaf they own.
#[derive(Clone, PartialEq, Eq)]
pub struct PlayerShare {
    /// 1-based player identifier.
    pub player: usize,
    /// One fragment per `Party` leaf owned by `player`.
    pub fragments: Vec<ShareFragment>,
}

impl core::fmt::Debug for PlayerShare {
    // Intentionally prints nothing about the contents, not even the player id.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        write!(f, "PlayerShare(<elided>)")
    }
}
/// Shares `secret` (first reduced into `field`) according to `formula`.
///
/// Returns one [`PlayerShare`] per distinct player appearing in `formula`, in
/// ascending player-id order. Each player gets one fragment per leaf they own,
/// in the order the leaves are visited depth-first.
///
/// # Panics
/// If `formula` contains a `Party(0)` leaf, or an `And`/`Or` node with no children.
#[must_use]
pub fn split<R: Csprng>(
    field: &PrimeField,
    rng: &mut R,
    secret: &BigUint,
    formula: &Formula,
) -> Vec<PlayerShare> {
    let canonical = field.reduce(secret);
    let mut leaves: Vec<(usize, ShareFragment)> = Vec::new();
    let mut scratch_path: Vec<u32> = Vec::new();
    distribute(field, rng, formula, &canonical, &mut scratch_path, &mut leaves);

    // BTreeMap orders players ascending and preserves each player's fragment order.
    let mut per_player = std::collections::BTreeMap::<usize, Vec<ShareFragment>>::new();
    for (player, fragment) in leaves {
        per_player.entry(player).or_default().push(fragment);
    }

    let mut shares = Vec::with_capacity(per_player.len());
    for (player, fragments) in per_player {
        shares.push(PlayerShare { player, fragments });
    }
    shares
}
/// Walks `path` (a sequence of child indices) down from the root of `formula`
/// and returns the player id of the `Party` leaf it lands on.
///
/// Returns `None` if the path tries to descend below a leaf, names a child index
/// that does not exist, or ends on an `And`/`Or` node instead of a leaf.
fn leaf_party_at_path(formula: &Formula, path: &[u32]) -> Option<usize> {
    let target = path.iter().try_fold(formula, |node, &step| match node {
        Formula::Party(_) => None,
        Formula::And(children) | Formula::Or(children) => children.get(step as usize),
    })?;
    match target {
        Formula::Party(player) => Some(*player),
        Formula::And(_) | Formula::Or(_) => None,
    }
}
/// Recursively assigns `value` to the subtree rooted at `node`. It appends one
/// `(player, fragment)` pair to `out` for every `Party` leaf, in depth-first order.
///
/// - `Party(j)`: records `value` at the current `path` for player `j`.
/// - `And` with `m` children: draws `m - 1` elements with `field.random(rng)`. The last
///   piece is set to `value` minus their sum, so the `m` pieces add up to `value`.
///   Child `i` receives piece `i`.
/// - `Or`: hands `value` unchanged to every child.
///
/// `path` is a scratch stack of child indices from the root. Each child visit pushes
/// its index and pops it afterwards, so `path` is unchanged when this returns.
///
/// # Panics
/// On a `Party(0)` leaf, or an `And`/`Or` node with no children.
fn distribute<R: Csprng>(
    field: &PrimeField,
    rng: &mut R,
    node: &Formula,
    value: &BigUint,
    path: &mut Vec<u32>,
    out: &mut Vec<(usize, ShareFragment)>,
) {
    match node {
        Formula::Party(player) => {
            assert!(*player != 0, "player identifiers are 1-based");
            let fragment = ShareFragment {
                path: path.clone(),
                value: value.clone(),
            };
            out.push((*player, fragment));
        }
        Formula::And(children) => {
            assert!(!children.is_empty(), "AND node must have children");
            let count = children.len();
            let mut pieces: Vec<BigUint> = Vec::with_capacity(count);
            // Random pieces for every child but the last, drawn in child order.
            for _ in 1..count {
                pieces.push(field.random(rng));
            }
            let drawn_total = pieces
                .iter()
                .fold(BigUint::zero(), |acc, piece| field.add(&acc, piece));
            // The final piece makes all pieces sum to `value`.
            pieces.push(field.sub(value, &drawn_total));
            for (index, (child, piece)) in children.iter().zip(&pieces).enumerate() {
                path.push(index as u32);
                distribute(field, rng, child, piece, path, out);
                path.pop();
            }
        }
        Formula::Or(children) => {
            assert!(!children.is_empty(), "OR node must have children");
            for (index, child) in children.iter().enumerate() {
                path.push(index as u32);
                distribute(field, rng, child, value, path, out);
                path.pop();
            }
        }
    }
}
/// Reconstructs the secret shared under `formula` from a coalition's shares.
///
/// Returns `None` when any of the following holds:
/// - two entries in `shares` name the same player;
/// - a fragment's `path` does not lead to a `Party` leaf owned by the player presenting
///   it, so a player cannot inject a value at someone else's leaf;
/// - two fragments at the same path carry different values (compared via
///   `ct_eq_biguint`);
/// - the supplied fragments do not satisfy `formula` (the coalition is not qualified).
#[must_use]
pub fn reconstruct(
    field: &PrimeField,
    formula: &Formula,
    shares: &[PlayerShare],
) -> Option<BigUint> {
    // Reject a coalition that lists the same player twice. A set makes this a single
    // O(n) pass instead of comparing every pair of shares.
    let mut seen_players = std::collections::HashSet::with_capacity(shares.len());
    if !shares.iter().all(|share| seen_players.insert(share.player)) {
        return None;
    }

    let mut by_path: std::collections::HashMap<Vec<u32>, BigUint> =
        std::collections::HashMap::new();
    for share in shares {
        for frag in &share.fragments {
            // A fragment is admissible only at a leaf owned by the player presenting it.
            if leaf_party_at_path(formula, &frag.path) != Some(share.player) {
                return None;
            }
            // A player may list the same path more than once; accept that only if the
            // values agree. The first occurrence is the one kept.
            if let Some(prev) = by_path.get(&frag.path) {
                if !ct_eq_biguint(prev, &frag.value) {
                    return None;
                }
            } else {
                by_path.insert(frag.path.clone(), frag.value.clone());
            }
        }
    }
    recover(field, formula, &mut Vec::new(), &by_path)
}
/// Recursively rebuilds the value held at `node` from the fragments in `by_path`.
///
/// `path` is a scratch stack holding the child indices from the root down to `node`.
/// It is back in its entry state whenever this function returns.
///
/// - `Party` leaf: the fragment stored at `path`, reduced into `field`, or `None` if the
///   coalition supplied no fragment for that leaf. Reducing keeps the result canonical
///   even if a presented fragment value was not.
/// - `And`: the field sum of every child's value. Returns `None` if any child is
///   unrecoverable, or if the node has no children. `split` refuses to share under an
///   empty AND, so one must not "reconstruct" zero out of no fragments at all.
/// - `Or`: the value of the first recoverable child, in index order.
///
/// Limitation: OR branches are not cross-checked against one another, so a tampered
/// fragment on an earlier branch yields a wrong value rather than `None`.
fn recover(
    field: &PrimeField,
    node: &Formula,
    path: &mut Vec<u32>,
    by_path: &std::collections::HashMap<Vec<u32>, BigUint>,
) -> Option<BigUint> {
    match node {
        Formula::Party(_) => by_path.get(path).map(|v| field.reduce(v)),
        Formula::And(children) => {
            if children.is_empty() {
                return None;
            }
            let mut sum = BigUint::zero();
            for (j, child) in children.iter().enumerate() {
                path.push(j as u32);
                let part = recover(field, child, path, by_path);
                // Pop before `?` so an OR ancestor that moves on to its next child
                // still sees a balanced `path`.
                path.pop();
                sum = field.add(&sum, &part?);
            }
            Some(sum)
        }
        Formula::Or(children) => {
            for (j, child) in children.iter().enumerate() {
                path.push(j as u32);
                let part = recover(field, child, path, by_path);
                path.pop();
                if let Some(v) = part {
                    return Some(v);
                }
            }
            None
        }
    }
}
// Unit tests for formula-based sharing. Every test uses the same fixed-seed ChaCha20
// RNG and a small prime field, so the generated shares are reproducible run to run.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::csprng::ChaCha20Rng;

    // Fixed 32-byte seed so every test draws the same random stream.
    fn rng() -> ChaCha20Rng {
        ChaCha20Rng::from_seed(&[0x77u8; 32])
    }

    // Field built from the prime 65537; small enough that test secrets fit in a u64.
    fn small() -> PrimeField {
        PrimeField::new(BigUint::from_u64(65_537))
    }

    // Clones the shares whose player id is in `players`, keeping the order of `shares`.
    fn pick(shares: &[PlayerShare], players: &[usize]) -> Vec<PlayerShare> {
        shares
            .iter()
            .filter(|s| players.contains(&s.player))
            .cloned()
            .collect()
    }

    // 2-of-3 threshold written as an OR over every qualified pair. Each player sits in
    // two AND branches and so holds two fragments. Every pair reconstructs; no single
    // player does.
    #[test]
    fn two_of_three_via_formula() {
        let f = small();
        let mut r = rng();
        let formula = Formula::or(vec![
            Formula::and(vec![Formula::party(1), Formula::party(2)]),
            Formula::and(vec![Formula::party(1), Formula::party(3)]),
            Formula::and(vec![Formula::party(2), Formula::party(3)]),
        ]);
        let secret = BigUint::from_u64(0xBEEF);
        let shares = split(&f, &mut r, &secret, &formula);
        for s in &shares {
            assert_eq!(s.fragments.len(), 2, "player {} fragment count", s.player);
        }
        for &(a, b) in &[(1usize, 2usize), (1, 3), (2, 3)] {
            let coalition = pick(&shares, &[a, b]);
            assert_eq!(
                reconstruct(&f, &formula, &coalition),
                Some(secret.clone()),
                "qualified pair ({a},{b})"
            );
        }
        for j in 1..=3 {
            let solo = pick(&shares, &[j]);
            assert!(reconstruct(&f, &formula, &solo).is_none(), "singleton {j}");
        }
    }

    // 1 AND (2 OR (3 AND 4)): player 1 is always required, together with either
    // player 2 or both of players 3 and 4.
    #[test]
    fn nested_formula() {
        let f = small();
        let mut r = rng();
        let formula = Formula::and(vec![
            Formula::party(1),
            Formula::or(vec![
                Formula::party(2),
                Formula::and(vec![Formula::party(3), Formula::party(4)]),
            ]),
        ]);
        let secret = BigUint::from_u64(0xC0DE);
        let shares = split(&f, &mut r, &secret, &formula);
        for q in &[vec![1, 2], vec![1, 3, 4], vec![1, 2, 3, 4]] {
            let c = pick(&shares, q);
            assert_eq!(reconstruct(&f, &formula, &c), Some(secret.clone()), "qualifies {q:?}");
        }
        for q in &[vec![1usize], vec![2, 3, 4], vec![1, 3], vec![1, 4]] {
            let c = pick(&shares, q);
            assert!(reconstruct(&f, &formula, &c).is_none(), "forbidden {q:?}");
        }
    }

    // An OR at the root copies the secret verbatim to each child, so any one player
    // recovers it alone.
    #[test]
    fn or_root_replicates() {
        let f = small();
        let mut r = rng();
        let formula = Formula::or(vec![Formula::party(1), Formula::party(2)]);
        let secret = BigUint::from_u64(7);
        let shares = split(&f, &mut r, &secret, &formula);
        for s in &shares {
            assert_eq!(s.fragments.len(), 1);
            assert_eq!(s.fragments[0].value, secret);
        }
        let solo = pick(&shares, &[1]);
        assert_eq!(reconstruct(&f, &formula, &solo), Some(secret));
    }

    // An AND at the root splits additively: the three fragments sum to the secret,
    // and no pair of players reconstructs it.
    #[test]
    fn and_root_requires_all() {
        let f = small();
        let mut r = rng();
        let formula = Formula::and(vec![
            Formula::party(1),
            Formula::party(2),
            Formula::party(3),
        ]);
        let secret = BigUint::from_u64(0x1234);
        let shares = split(&f, &mut r, &secret, &formula);
        for s in &shares {
            assert_eq!(s.fragments.len(), 1);
        }
        let sum = shares
            .iter()
            .map(|s| s.fragments[0].value.clone())
            .fold(BigUint::zero(), |a, b| f.add(&a, &b));
        assert_eq!(sum, secret);
        assert_eq!(reconstruct(&f, &formula, &shares), Some(secret));
        for &(a, b) in &[(1usize, 2usize), (1, 3), (2, 3)] {
            let pair = pick(&shares, &[a, b]);
            assert!(reconstruct(&f, &formula, &pair).is_none());
        }
    }

    // Presenting the same player's share twice must be rejected, not treated as
    // two separate players.
    #[test]
    fn duplicate_player_rejected() {
        let f = small();
        let mut r = rng();
        let formula = Formula::or(vec![
            Formula::and(vec![Formula::party(1), Formula::party(2)]),
            Formula::and(vec![Formula::party(1), Formula::party(3)]),
        ]);
        let secret = BigUint::from_u64(99);
        let shares = split(&f, &mut r, &secret, &formula);
        let dup = vec![shares[0].clone(), shares[0].clone()];
        assert!(reconstruct(&f, &formula, &dup).is_none());
    }

    // Pins a known limitation. `recover` returns the first recoverable OR branch
    // without cross-checking the others. Tampering with player 1's fragment in branch 0
    // (path [0, 0], first after sorting) therefore yields a wrong secret instead of None,
    // even though branch 1 would give the right one.
    #[test]
    fn or_tamper_returns_wrong_value_first_branch_wins() {
        let f = small();
        let mut r = rng();
        let formula = Formula::or(vec![
            Formula::and(vec![Formula::party(1), Formula::party(2)]),
            Formula::and(vec![Formula::party(1), Formula::party(3)]),
        ]);
        let secret = BigUint::from_u64(33);
        let mut shares = split(&f, &mut r, &secret, &formula);
        let p1 = shares.iter_mut().find(|s| s.player == 1).unwrap();
        p1.fragments.sort_by(|a, b| a.path.cmp(&b.path));
        p1.fragments[0].value = f.add(&p1.fragments[0].value, &BigUint::from_u64(1));
        let coalition = pick(&shares, &[1, 2, 3]);
        let got = reconstruct(&f, &formula, &coalition).expect("OR's first branch still 'recovers'");
        assert_ne!(got, secret);
    }

    // Player 1 alone appends a forged fragment at player 2's leaf. The path-ownership
    // check in `reconstruct` rejects it.
    #[test]
    fn solo_forger_at_others_path_is_rejected() {
        let f = small();
        let mut r = rng();
        let formula = Formula::or(vec![Formula::party(1), Formula::party(2)]);
        let secret = BigUint::from_u64(77);
        let shares = split(&f, &mut r, &secret, &formula);
        let p2_path = shares
            .iter()
            .find(|s| s.player == 2)
            .unwrap()
            .fragments[0]
            .path
            .clone();
        let mut p1 = shares.iter().find(|s| s.player == 1).unwrap().clone();
        p1.fragments.push(ShareFragment {
            path: p2_path,
            value: BigUint::from_u64(0xDEAD),
        });
        let coalition = vec![p1];
        assert!(reconstruct(&f, &formula, &coalition).is_none());
    }

    // Player 1 presents a conflicting value at player 2's leaf while player 2 presents
    // the genuine one. This must be rejected: player 1 does not own that leaf.
    #[test]
    fn dishonest_replication_across_players_rejected() {
        let f = small();
        let mut r = rng();
        let formula = Formula::or(vec![Formula::party(1), Formula::party(2)]);
        let secret = BigUint::from_u64(55);
        let shares = split(&f, &mut r, &secret, &formula);
        let p2 = shares.iter().find(|s| s.player == 2).unwrap().clone();
        let mut p1 = shares.iter().find(|s| s.player == 1).unwrap().clone();
        p1.fragments.push(ShareFragment {
            path: p2.fragments[0].path.clone(),
            value: f.add(&p2.fragments[0].value, &BigUint::from_u64(1)),
        });
        let conflict = vec![p1, p2];
        assert!(reconstruct(&f, &formula, &conflict).is_none());
    }
}