const PRIME: u16 = 251;
fn get_random_bytes(n: usize) -> Vec<u8> {
let mut buf = vec![0u8; n];
getrandom::fill(&mut buf).expect("Failed to get random bytes");
buf
}
#[derive(Debug, Clone)]
pub struct Share {
pub index: u8,
pub data: Vec<u8>,
}
fn gf_inv(a: u8) -> u8 {
if a == 0 {
panic!("Cannot invert zero in GF(251)");
}
let mut result: u16 = 1;
let mut base: u16 = a as u16;
let mut exp: u16 = PRIME - 2;
while exp > 0 {
if exp % 2 == 1 {
result = (result * base) % PRIME;
}
base = (base * base) % PRIME;
exp /= 2;
}
result as u8
}
pub fn split(secret: &[u8], threshold: usize, num_shares: usize) -> Vec<Share> {
assert!(threshold >= 2, "Threshold must be >= 2");
assert!(num_shares >= threshold, "Shares must be >= threshold");
assert!(num_shares <= 250, "Shares must be <= 250 in GF(251)");
let mut shares: Vec<Share> = (0..num_shares)
.map(|i| Share {
index: (i + 1) as u8,
data: vec![0u8; secret.len()],
})
.collect();
for byte_idx in 0..secret.len() {
let secret_byte = secret[byte_idx];
let (b1, b2) = if secret_byte < PRIME as u8 {
(secret_byte, 0)
} else {
let b1 = 1u8;
let b2 = secret_byte ^ b1;
debug_assert!(b2 < PRIME as u8);
(b1, b2)
};
let mut coeffs = vec![0u8; threshold];
coeffs[0] = b1;
let rand_bytes = get_random_bytes(threshold - 1);
for (i, &b) in rand_bytes.iter().enumerate() {
coeffs[i + 1] = (b % (PRIME as u8 - 1)) + 1; }
for share in shares.iter_mut() {
let x = share.index as u16;
let mut y: u16 = 0;
let mut x_pow: u16 = 1;
for &coeff in &coeffs {
let term = (coeff as u16 * x_pow) % PRIME;
y = (y + term) % PRIME;
x_pow = (x_pow * x) % PRIME;
}
share.data[byte_idx] = y as u8;
}
if b2 != 0 {
let mut coeffs2 = vec![0u8; threshold];
coeffs2[0] = b2;
let rand_bytes2 = get_random_bytes(threshold - 1);
for (i, &b) in rand_bytes2.iter().enumerate() {
coeffs2[i + 1] = (b % (PRIME as u8 - 1)) + 1;
}
}
}
shares
}
pub fn reconstruct(shares: &[Share]) -> Vec<u8> {
assert!(!shares.is_empty(), "Need at least one share");
let secret_len = shares[0].data.len();
let t = shares.len();
let mut secret = vec![0u8; secret_len];
for byte_idx in 0..secret_len {
let mut result: u16 = 0;
for i in 0..t {
let xi = shares[i].index as u16;
let yi = shares[i].data[byte_idx] as u16;
let mut numerator: u16 = 1;
let mut denominator: u16 = 1;
for j in 0..t {
if i == j { continue; }
let xj = shares[j].index as u16;
numerator = (numerator * ((PRIME - xj) % PRIME)) % PRIME;
denominator = (denominator * ((xi + PRIME - xj) % PRIME)) % PRIME;
}
let lagrange_coef = (numerator * gf_inv(denominator as u8) as u16) % PRIME;
let term = (lagrange_coef * yi) % PRIME;
result = (result + term) % PRIME;
}
secret[byte_idx] = result as u8;
}
secret
}
pub fn encode_shares(shares: &[Share]) -> Vec<String> {
shares
.iter()
.map(|s| {
let encoded = base64::Engine::encode(
&base64::engine::general_purpose::STANDARD,
&s.data,
);
format!("{}:{}", s.index, encoded)
})
.collect()
}
pub fn decode_shares(encoded: &[&str]) -> Result<Vec<Share>, String> {
encoded
.iter()
.map(|s| {
let parts: Vec<&str> = s.splitn(2, ':').collect();
if parts.len() != 2 {
return Err(format!("Invalid share format: {}", s));
}
let index: u8 = parts[0].parse().map_err(|e| format!("Invalid index: {}", e))?;
let data = base64::Engine::decode(
&base64::engine::general_purpose::STANDARD,
parts[1],
)
.map_err(|e| format!("Invalid base64: {}", e))?;
Ok(Share { index, data })
})
.collect()
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_split_and_reconstruct() {
let secret = b"top-secret-api-key";
let shares = split(secret, 3, 5);
assert_eq!(shares.len(), 5);
let subset = vec![shares[0].clone(), shares[2].clone(), shares[4].clone()];
let recovered = reconstruct(&subset);
assert_eq!(recovered, secret.to_vec());
}
#[test]
fn test_reconstruct_with_different_shares() {
let secret = b"another-secret";
let shares = split(secret, 2, 4);
let recovered1 = reconstruct(&[shares[0].clone(), shares[1].clone()]);
let recovered2 = reconstruct(&[shares[2].clone(), shares[3].clone()]);
let recovered3 = reconstruct(&[shares[1].clone(), shares[3].clone()]);
assert_eq!(recovered1, secret.to_vec());
assert_eq!(recovered2, secret.to_vec());
assert_eq!(recovered3, secret.to_vec());
}
#[test]
fn test_insufficient_shares_fail() {
let secret = b"secret";
let shares = split(secret, 3, 5);
let recovered = reconstruct(&[shares[0].clone(), shares[1].clone()]);
assert_ne!(recovered, secret.to_vec());
}
#[test]
fn test_encode_decode() {
let secret = b"test-secret";
let shares = split(secret, 2, 3);
let encoded = encode_shares(&shares);
assert_eq!(encoded.len(), 3);
let refs: Vec<&str> = encoded.iter().map(|s| s.as_str()).collect();
let decoded = decode_shares(&refs).unwrap();
let recovered = reconstruct(&decoded);
assert_eq!(recovered, secret.to_vec());
}
#[test]
fn test_all_shares_reconstruct() {
let secret = b"full-reconstruction";
let shares = split(secret, 4, 4);
let recovered = reconstruct(&shares);
assert_eq!(recovered, secret.to_vec());
}
#[test]
fn test_threshold_2_of_3() {
let secret = b"abir-guard-test";
let shares = split(secret, 2, 3);
for i in 0..3 {
for j in (i+1)..3 {
let recovered = reconstruct(&[shares[i].clone(), shares[j].clone()]);
assert_eq!(recovered, secret.to_vec(), "Failed with shares {} and {}", i, j);
}
}
}
}