use sharks::{Share, Sharks};
use crate::error::CollabError;
/// Splits a 32-byte recovery key into `shares` Shamir shares, any `threshold`
/// of which are intended to be sufficient to rebuild the key.
///
/// Every returned share is framed as one leading byte holding `threshold`,
/// followed by the byte serialization of the underlying `sharks` share.
///
/// # Panics
///
/// Panics if `threshold` is 0 or if `threshold` is greater than `shares`.
pub fn split_recovery_key(key: &[u8; 32], threshold: u8, shares: u8) -> Vec<Vec<u8>> {
    assert!(threshold > 0, "threshold must be at least 1");
    assert!(threshold <= shares, "threshold must be ≤ shares");

    let share_count = shares as usize;
    let scheme = Sharks(threshold);

    let mut framed_shares = Vec::with_capacity(share_count);
    for share in scheme.dealer(key).take(share_count) {
        // Frame layout: [threshold][serialized share]. The reserved capacity
        // mirrors 1 prefix byte + (presumably) 1 x-coordinate byte + 32 y bytes.
        let mut framed = Vec::with_capacity(1 + 1 + 32);
        framed.push(threshold);
        framed.extend(Vec::from(&share));
        framed_shares.push(framed);
    }
    framed_shares
}
/// Rebuilds the 32-byte recovery key from shares framed by `split_recovery_key`.
///
/// Each share is expected to be one threshold byte followed by a serialized
/// `sharks` share. The threshold is read from the first share and every other
/// share must declare the same value.
///
/// # Errors
///
/// Returns `CollabError::RecoveryFailed` when:
/// - `shares` is empty, or the first share contains no bytes;
/// - the declared threshold is 0;
/// - any share is empty, declares a different threshold than the first share,
///   or has a body that `Share::try_from` rejects;
/// - `Sharks::recover` reports an error (for example, too few shares);
/// - the reconstructed secret is not exactly 32 bytes long.
pub fn reconstruct_recovery_key(shares: &[Vec<u8>]) -> Result<[u8; 32], CollabError> {
    let threshold = shares
        .first()
        .ok_or_else(|| CollabError::RecoveryFailed("no shares provided".to_owned()))?
        .first()
        .copied()
        .ok_or_else(|| CollabError::RecoveryFailed("share is empty".to_owned()))?;

    // `split_recovery_key` asserts `threshold > 0`, so a zero prefix can only
    // come from a corrupted or foreign share. Reject it instead of letting it
    // weaken the minimum-share requirement.
    if threshold == 0 {
        return Err(CollabError::RecoveryFailed(
            "share declares an invalid threshold of 0".to_owned(),
        ));
    }

    let parsed = shares
        .iter()
        .map(|raw| {
            let (&declared, body) = raw.split_first().ok_or("share too short")?;
            // `split_recovery_key` stamps the same threshold on every share of
            // a split; a disagreement means the shares are corrupted or come
            // from different splits, so trusting only the first prefix would
            // be unsafe.
            if declared != threshold {
                return Err("share threshold does not match the first share");
            }
            Share::try_from(body)
        })
        .collect::<Result<Vec<Share>, _>>()
        .map_err(|e| {
            CollabError::RecoveryFailed(format!("one or more shares could not be parsed: {e}"))
        })?;

    let sharks = Sharks(threshold);
    let secret = sharks
        .recover(&parsed)
        .map_err(|e| CollabError::RecoveryFailed(format!("Shamir reconstruction error: {e}")))?;

    // The caller contract is a fixed 32-byte key; any other length means the
    // shares did not originate from a 32-byte secret.
    secret.as_slice().try_into().map_err(|_| {
        CollabError::RecoveryFailed(format!(
            "reconstructed secret is {} bytes, expected 32",
            secret.len()
        ))
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The first two shares of a 2-of-3 split recover the original key.
    #[test]
    fn round_trip_2_of_3() {
        let key = [0xAB_u8; 32];
        let issued = split_recovery_key(&key, 2, 3);
        assert_eq!(issued.len(), 3);

        let rebuilt = reconstruct_recovery_key(&issued[..2]).unwrap();
        assert_eq!(rebuilt, key);
    }

    /// A non-contiguous subset (shares 0, 2 and 4) of a 3-of-5 split recovers the key.
    #[test]
    fn round_trip_3_of_5() {
        let key = [0x42_u8; 32];
        let issued = split_recovery_key(&key, 3, 5);
        assert_eq!(issued.len(), 5);

        // Every second share: indices 0, 2 and 4.
        let chosen: Vec<Vec<u8>> = issued.iter().step_by(2).cloned().collect();
        let rebuilt = reconstruct_recovery_key(&chosen).unwrap();
        assert_eq!(rebuilt, key);
    }

    /// Supplying more shares than the threshold still recovers the key.
    #[test]
    fn all_5_shares_reconstruct() {
        let key = [0xDE_u8; 32];
        let issued = split_recovery_key(&key, 3, 5);

        let rebuilt = reconstruct_recovery_key(&issued).unwrap();
        assert_eq!(rebuilt, key);
    }

    /// An empty share list is reported as `RecoveryFailed`.
    #[test]
    fn empty_shares_returns_error() {
        let outcome = reconstruct_recovery_key(&[]);
        assert!(matches!(outcome, Err(CollabError::RecoveryFailed(_))));
    }

    /// Two shares of a 3-of-5 split are not enough and yield `RecoveryFailed`.
    #[test]
    fn too_few_shares_returns_error() {
        let key = [0x11_u8; 32];
        let issued = split_recovery_key(&key, 3, 5);

        let result = reconstruct_recovery_key(&issued[..2]);
        assert!(
            matches!(result, Err(CollabError::RecoveryFailed(_))),
            "expected RecoveryFailed with too few shares, got: {result:?}"
        );
    }

    /// A share whose body has been bit-flipped must not yield the original key.
    ///
    /// Despite the test name, the leading threshold byte is left intact; only
    /// the bytes after it are inverted.
    #[test]
    fn corrupted_share_prefix_returns_error() {
        let key = [0x01_u8; 32];
        let mut issued = split_recovery_key(&key, 2, 3);

        if let Some((_, body)) = issued.first_mut().and_then(|share| share.split_first_mut()) {
            for byte in body {
                *byte = !*byte;
            }
        }

        let outcome = reconstruct_recovery_key(&issued[..2]);
        // Either reconstruction fails outright or it produces a different key.
        assert!(outcome.map_or(true, |rebuilt| rebuilt != key));
    }
}