use crate::field::PrimeField;
use crate::poly::lagrange_eval;
use crate::shamir::Share;
use crate::bigint::BigUint;
/// Splits a `k`-component secret into `n` ramp shares over `field`, where
/// `k = secret.len()`.
///
/// The secret components are pinned as the values of an interpolating
/// polynomial at `x = 1..=k` (component `j`, 0-based, sits at `x = j + 1`).
/// Share `i`, for `i` in `1..=n`, is that polynomial evaluated at `x = k + i`.
/// Because the share labels start just past the anchors, no share's label
/// coincides with a secret position. Any `k` of the returned shares are
/// intended to be fed to [`reconstruct`].
///
/// Each component is passed through `field.reduce` before use, so
/// reconstruction yields the components in reduced form.
///
/// # Panics
///
/// - if `secret.len() < 2`;
/// - if `n < secret.len()`;
/// - if the field modulus is not strictly greater than `secret.len() + n`,
///   because every label in `1..=k + n` must be a distinct field element.
#[must_use]
pub fn split(field: &PrimeField, secret: &[BigUint], n: usize) -> Vec<Share> {
    let k = secret.len();
    assert!(
        k >= 2,
        "secret must have ≥ 2 components (k = 1 makes every share equal to the secret)"
    );
    assert!(
        n >= k,
        "n must be ≥ k = secret.len() — otherwise no subset of the n shares can reconstruct the k-element secret",
    );
    let highest_label = BigUint::from_u64((k + n) as u64);
    assert!(highest_label < *field.modulus(), "prime modulus must exceed k + n");

    // Interpolation nodes: the secret components at x = 1..=k.
    let mut anchors: Vec<(BigUint, BigUint)> = Vec::with_capacity(k);
    for (offset, component) in secret.iter().enumerate() {
        let label = BigUint::from_u64((offset + 1) as u64);
        anchors.push((label, field.reduce(component)));
    }

    // Shares live at x = k+1..=k+n, disjoint from the anchor labels above.
    let mut shares = Vec::with_capacity(n);
    for i in 1..=n {
        let label = BigUint::from_u64((k + i) as u64);
        let value = lagrange_eval(field, &anchors, &label).expect("distinct anchors");
        shares.push(Share { x: label, y: value });
    }
    shares
}
/// Recovers the `k` secret components from shares produced by [`split`].
///
/// Every supplied share is validated. The first `k` shares are then used as
/// interpolation points, and the interpolating polynomial is evaluated at
/// `x = 1..=k`, which yields the components in order.
///
/// Labels (`x`) and values (`y`) are reduced with `field.reduce` once. The
/// reduced values are used both for validation and for interpolation, so two
/// labels that differ by a multiple of the modulus are presented to
/// `lagrange_eval` as the same point rather than as distinct ones.
///
/// Returns `None` if:
/// - `shares` is empty, `k == 0`, or fewer than `k` shares are given;
/// - any share's label reduces to `0` or into `1..=k`, the positions reserved
///   for the secret components;
/// - `lagrange_eval` returns `None` for the first `k` points
///   (NOTE(review): presumably this covers duplicate labels — confirm
///   against `lagrange_eval`).
#[must_use]
pub fn reconstruct(field: &PrimeField, shares: &[Share], k: usize) -> Option<Vec<BigUint>> {
    if shares.is_empty() || k == 0 || shares.len() < k {
        return None;
    }
    let k_big = BigUint::from_u64(k as u64);
    let mut pts: Vec<(BigUint, BigUint)> = Vec::with_capacity(k);
    for (idx, s) in shares.iter().enumerate() {
        let xr = field.reduce(&s.x);
        // `split` only emits labels k+1..=k+n; 0 and the anchor positions
        // 1..=k can never belong to a genuine share.
        if xr.is_zero() || xr <= k_big {
            return None;
        }
        // All shares are validated, but only the first k are interpolated.
        // Use the reduced label so validation and interpolation agree.
        if idx < k {
            pts.push((xr, field.reduce(&s.y)));
        }
    }
    (1..=k)
        .map(|j| lagrange_eval(field, &pts, &BigUint::from_u64(j as u64)))
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Field over the Mersenne prime 2^61 - 1.
    fn small_field() -> PrimeField {
        PrimeField::new(BigUint::from_u64((1u64 << 61) - 1))
    }

    #[test]
    fn ramp_round_trip() {
        let f = small_field();
        let secret: Vec<BigUint> = (1..=4).map(|i| BigUint::from_u64(100 + i)).collect();
        let n = 6;
        let shares = split(&f, &secret, n);
        assert_eq!(shares.len(), n);
        assert_eq!(reconstruct(&f, &shares[..4], 4), Some(secret.clone()));
        assert_eq!(reconstruct(&f, &shares[2..], 4), Some(secret));
    }

    #[test]
    fn ramp_payload_is_one_field_element_per_share() {
        let f = small_field();
        let secret: Vec<BigUint> = (1..=10).map(BigUint::from_u64).collect();
        let shares = split(&f, &secret, 12);
        assert_eq!(shares.len(), 12);
        // Share labels continue directly after the k = 10 anchor positions.
        let labels: Vec<BigUint> = shares.iter().map(|s| s.x.clone()).collect();
        let expected: Vec<BigUint> = (11..=22).map(BigUint::from_u64).collect();
        assert_eq!(labels, expected);
        assert_eq!(reconstruct(&f, &shares[..10], 10).unwrap(), secret);
    }

    #[test]
    #[should_panic(expected = "n must be ≥ k")]
    fn ramp_split_rejects_n_below_k() {
        let f = small_field();
        let secret: Vec<BigUint> = (1..=4).map(BigUint::from_u64).collect();
        let _ = split(&f, &secret, 3);
    }

    #[test]
    #[should_panic(expected = "secret must have ≥ 2 components")]
    fn ramp_split_rejects_single_component_secret() {
        let f = small_field();
        let secret = vec![BigUint::from_u64(7)];
        let _ = split(&f, &secret, 3);
    }

    #[test]
    fn ramp_reconstruct_rejects_secret_anchor_labels() {
        let f = small_field();
        let secret: Vec<BigUint> = (1..=3).map(BigUint::from_u64).collect();
        let mut shares = split(&f, &secret, 5);
        shares[0].x = BigUint::from_u64(2);
        assert!(reconstruct(&f, &shares[..3], 3).is_none());
    }

    #[test]
    fn ramp_reconstruct_rejects_zero_label() {
        let f = small_field();
        let secret: Vec<BigUint> = (1..=3).map(BigUint::from_u64).collect();
        let mut shares = split(&f, &secret, 5);
        shares[1].x = BigUint::from_u64(0);
        assert!(reconstruct(&f, &shares[..3], 3).is_none());
    }

    #[test]
    fn ramp_reconstruct_rejects_zero_threshold() {
        let f = small_field();
        let secret: Vec<BigUint> = (1..=3).map(BigUint::from_u64).collect();
        let shares = split(&f, &secret, 5);
        assert!(reconstruct(&f, &shares, 0).is_none());
    }

    #[test]
    fn ramp_below_threshold_returns_none() {
        let f = small_field();
        let secret: Vec<BigUint> = (1..=3).map(|i| BigUint::from_u64(10 + i)).collect();
        let shares = split(&f, &secret, 5);
        assert!(reconstruct(&f, &shares[..2], 3).is_none());
    }
}