use crate::error::{Error, Result};
use crate::primitives::bsv::polynomial::{PointInFiniteField, Polynomial};
use crate::primitives::encoding::{from_base58, to_base58};
use crate::primitives::hash::sha256;
use crate::primitives::BigNumber;
use crate::primitives::PrivateKey;
/// A private key split into shares by [`split_private_key`].
///
/// Any `threshold` of the `points` are meant to be enough to rebuild the key with
/// [`KeyShares::recover_private_key`]. `integrity` is a short tag derived from the
/// original key (see `compute_integrity`) that recovery compares against the
/// reconstructed key.
#[derive(Clone, Debug)]
pub struct KeyShares {
    /// The share points `(x, y)`; `split_private_key` uses x = 1..=total.
    pub points: Vec<PointInFiniteField>,
    /// Minimum number of shares required to recover the key.
    pub threshold: usize,
    /// Integrity tag of the original key, checked after recovery.
    pub integrity: String,
}
impl KeyShares {
pub fn new(points: Vec<PointInFiniteField>, threshold: usize, integrity: String) -> Self {
Self {
points,
threshold,
integrity,
}
}
pub fn from_backup_format(shares: &[String]) -> Result<Self> {
if shares.is_empty() {
return Err(Error::CryptoError(
"No shares provided for recovery".to_string(),
));
}
let mut points = Vec::with_capacity(shares.len());
let mut threshold: Option<usize> = None;
let mut integrity: Option<String> = None;
for (idx, share) in shares.iter().enumerate() {
let (point, t, i) = decode_share(share)?;
if let Some(existing_threshold) = threshold {
if existing_threshold != t {
return Err(Error::CryptoError(format!(
"Threshold mismatch: share 0 has threshold {}, share {} has threshold {}",
existing_threshold, idx, t
)));
}
} else {
threshold = Some(t);
}
if let Some(ref existing_integrity) = integrity {
if existing_integrity != &i {
return Err(Error::CryptoError(format!(
"Integrity mismatch: share 0 has integrity '{}', share {} has integrity '{}'",
existing_integrity, idx, i
)));
}
} else {
integrity = Some(i);
}
points.push(point);
}
Ok(Self {
points,
threshold: threshold.unwrap(),
integrity: integrity.unwrap(),
})
}
pub fn to_backup_format(&self) -> Vec<String> {
self.points
.iter()
.map(|point| {
format!(
"{}.{}.{}",
point.to_point_string(),
self.threshold,
self.integrity
)
})
.collect()
}
pub fn recover_private_key(&self) -> Result<PrivateKey> {
if self.points.len() < self.threshold {
return Err(Error::CryptoError(format!(
"Insufficient shares: have {}, need {}",
self.points.len(),
self.threshold
)));
}
let poly = Polynomial::new(self.points.clone(), self.threshold);
let secret = poly.value_at(&BigNumber::zero());
let secret_bytes = secret.to_bytes_be(32);
let key = PrivateKey::from_bytes(&secret_bytes)?;
let computed_integrity = compute_integrity(&key);
if computed_integrity != self.integrity {
return Err(Error::CryptoError(format!(
"Integrity check failed: computed '{}', expected '{}'",
computed_integrity, self.integrity
)));
}
Ok(key)
}
}
pub fn split_private_key(key: &PrivateKey, threshold: usize, total: usize) -> Result<KeyShares> {
if threshold < 2 {
return Err(Error::CryptoError(
"Threshold must be at least 2".to_string(),
));
}
if total < threshold {
return Err(Error::CryptoError(format!(
"Total shares ({}) must be at least threshold ({})",
total, threshold
)));
}
if threshold > 255 {
return Err(Error::CryptoError(
"Threshold cannot exceed 255".to_string(),
));
}
let p = BigNumber::secp256k1_prime();
let secret = BigNumber::from_bytes_be(&key.to_bytes());
let mut coefficients = Vec::with_capacity(threshold);
coefficients.push(secret);
for _ in 1..threshold {
let random_key = PrivateKey::random();
let coeff = BigNumber::from_bytes_be(&random_key.to_bytes()).modulo(&p);
coefficients.push(coeff);
}
let mut points = Vec::with_capacity(total);
for i in 1..=total {
let x = BigNumber::from_u64(i as u64);
let y = evaluate_polynomial(&coefficients, &x, &p);
points.push(PointInFiniteField::new(x, y));
}
let integrity = compute_integrity(key);
Ok(KeyShares {
points,
threshold,
integrity,
})
}
/// Evaluates `c[0] + c[1]·x + … + c[k]·x^k`, reduced mod `modulus`, with Horner's rule.
///
/// `coefficients` are ordered from the constant term upward, so
/// `coefficients[0]` is the value at `x = 0`. An empty slice yields zero.
fn evaluate_polynomial(
    coefficients: &[BigNumber],
    x: &BigNumber,
    modulus: &BigNumber,
) -> BigNumber {
    // Start at the highest-degree coefficient and reduce after every step, so the
    // accumulator does not grow with the polynomial's degree.
    coefficients
        .iter()
        .rev()
        .fold(BigNumber::zero(), |acc, coeff| {
            acc.mul(x).add(coeff).modulo(modulus)
        })
}
/// Derives the short integrity tag stored with a key's shares.
///
/// The tag is the first four characters of the Base58 encoding of
/// `sha256(key bytes)`. If that encoding is shorter, the whole encoding is used.
/// `KeyShares::recover_private_key` compares this tag against the recovered key.
fn compute_integrity(key: &PrivateKey) -> String {
    let mut tag = to_base58(&sha256(&key.to_bytes()));
    // `truncate` leaves strings of four characters or fewer untouched.
    tag.truncate(4);
    tag
}
/// Parses one backup share of the form `base58(x).base58(y).threshold.integrity`.
///
/// Returns the decoded point, the threshold, and the integrity tag exactly as
/// written.
///
/// # Errors
///
/// Returns `Error::CryptoError` if the share does not have exactly four
/// `.`-separated fields, or if the threshold field is not a valid `usize`.
/// Errors from `from_base58` on the x/y fields are propagated (x is decoded first).
fn decode_share(share: &str) -> Result<(PointInFiniteField, usize, String)> {
    let fields: Vec<&str> = share.split('.').collect();
    let (x_field, y_field, threshold_field, integrity_field) = match fields.as_slice() {
        [x, y, t, i] => (*x, *y, *t, *i),
        _ => {
            return Err(Error::CryptoError(format!(
                "Invalid share format: expected 'base58(x).base58(y).threshold.integrity', got '{}'",
                share
            )));
        }
    };

    let x_raw = from_base58(x_field)?;
    let y_raw = from_base58(y_field)?;
    let point = PointInFiniteField::new(
        BigNumber::from_bytes_be(&x_raw),
        BigNumber::from_bytes_be(&y_raw),
    );

    let threshold = threshold_field.parse::<usize>().map_err(|e| {
        Error::CryptoError(format!(
            "Invalid threshold in share: {} ({})",
            threshold_field, e
        ))
    })?;

    Ok((point, threshold, integrity_field.to_string()))
}
#[cfg(test)]
mod tests {
    //! Unit tests for key splitting, backup encoding/decoding, and recovery.

    use super::*;

    /// Every 3-element subset of the share indices `0..5`. Used to check that any
    /// `threshold`-sized subset of a 3-of-5 split recovers the key.
    const THREE_OF_FIVE: [[usize; 3]; 10] = [
        [0, 1, 2],
        [0, 1, 3],
        [0, 1, 4],
        [0, 2, 3],
        [0, 2, 4],
        [0, 3, 4],
        [1, 2, 3],
        [1, 2, 4],
        [1, 3, 4],
        [2, 3, 4],
    ];

    #[test]
    fn test_split_recover_roundtrip() {
        let key = PrivateKey::random();
        let shares = split_private_key(&key, 3, 5).unwrap();
        assert_eq!(shares.points.len(), 5);
        assert_eq!(shares.threshold, 3);
        let subset = KeyShares::new(shares.points[0..3].to_vec(), 3, shares.integrity.clone());
        let recovered = subset.recover_private_key().unwrap();
        assert_eq!(key.to_bytes(), recovered.to_bytes());
    }

    #[test]
    fn test_split_recover_different_subsets() {
        let key = PrivateKey::random();
        let shares = split_private_key(&key, 3, 5).unwrap();
        for indices in &THREE_OF_FIVE {
            let points: Vec<_> = indices.iter().map(|&i| shares.points[i].clone()).collect();
            let subset = KeyShares::new(points, 3, shares.integrity.clone());
            let recovered = subset.recover_private_key().unwrap();
            assert_eq!(
                key.to_bytes(),
                recovered.to_bytes(),
                "Failed for indices {:?}",
                indices
            );
        }
    }

    #[test]
    fn test_backup_format_roundtrip() {
        let key = PrivateKey::random();
        let shares = split_private_key(&key, 3, 5).unwrap();
        let backup = shares.to_backup_format();
        assert_eq!(backup.len(), 5);
        for s in &backup {
            assert_eq!(s.split('.').count(), 4);
        }
        let restored = KeyShares::from_backup_format(&backup[1..4]).unwrap();
        let recovered = restored.recover_private_key().unwrap();
        assert_eq!(key.to_bytes(), recovered.to_bytes());
    }

    #[test]
    fn test_minimum_threshold() {
        let key = PrivateKey::random();
        let shares = split_private_key(&key, 2, 3).unwrap();
        let subset = KeyShares::new(shares.points[0..2].to_vec(), 2, shares.integrity.clone());
        let recovered = subset.recover_private_key().unwrap();
        assert_eq!(key.to_bytes(), recovered.to_bytes());
    }

    #[test]
    fn test_exact_threshold_equals_total() {
        let key = PrivateKey::random();
        let shares = split_private_key(&key, 5, 5).unwrap();
        let recovered = shares.recover_private_key().unwrap();
        assert_eq!(key.to_bytes(), recovered.to_bytes());
    }

    #[test]
    fn test_insufficient_shares() {
        let key = PrivateKey::random();
        let shares = split_private_key(&key, 3, 5).unwrap();
        let subset = KeyShares::new(shares.points[0..2].to_vec(), 3, shares.integrity.clone());
        let result = subset.recover_private_key();
        assert!(matches!(result, Err(Error::CryptoError(_))));
    }

    #[test]
    fn test_invalid_threshold() {
        let key = PrivateKey::random();
        assert!(matches!(
            split_private_key(&key, 1, 5),
            Err(Error::CryptoError(_))
        ));
        assert!(matches!(
            split_private_key(&key, 0, 5),
            Err(Error::CryptoError(_))
        ));
        assert!(matches!(
            split_private_key(&key, 5, 3),
            Err(Error::CryptoError(_))
        ));
    }

    #[test]
    fn test_integrity_check_fails_on_corruption() {
        let key = PrivateKey::random();
        let shares = split_private_key(&key, 2, 3).unwrap();
        // Valid points paired with a deliberately wrong integrity tag.
        let corrupted = KeyShares::new(shares.points.clone(), 2, "XXXX".to_string());
        let result = corrupted.recover_private_key();
        assert!(matches!(result, Err(Error::CryptoError(_))));
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Integrity check failed"));
    }

    #[test]
    fn test_backup_format_parsing() {
        let key = PrivateKey::random();
        let shares = split_private_key(&key, 2, 3).unwrap();
        let backup = shares.to_backup_format();
        let parsed = KeyShares::from_backup_format(&backup).unwrap();
        assert_eq!(parsed.threshold, shares.threshold);
        assert_eq!(parsed.integrity, shares.integrity);
        assert_eq!(parsed.points.len(), shares.points.len());
    }

    #[test]
    fn test_mismatched_threshold_in_shares() {
        // "2", "3", "abc" and "def" are all valid Base58, so decoding succeeds and
        // the threshold comparison is what fails.
        let share1 = "2.abc.3.XXXX".to_string();
        let share2 = "3.def.4.XXXX".to_string();
        let result = KeyShares::from_backup_format(&[share1, share2]);
        assert!(matches!(result, Err(Error::CryptoError(_))));
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Threshold mismatch"));
    }

    #[test]
    fn test_mismatched_integrity_in_shares() {
        let share1 = "2.abc.3.AAAA".to_string();
        let share2 = "3.def.3.BBBB".to_string();
        let result = KeyShares::from_backup_format(&[share1, share2]);
        assert!(matches!(result, Err(Error::CryptoError(_))));
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Integrity mismatch"));
    }

    #[test]
    fn test_known_private_key() {
        let key = PrivateKey::from_hex(
            "0000000000000000000000000000000000000000000000000000000000000001",
        )
        .unwrap();
        let shares = split_private_key(&key, 2, 3).unwrap();
        let recovered = shares.recover_private_key().unwrap();
        assert_eq!(key.to_bytes(), recovered.to_bytes());
    }

    #[test]
    fn test_large_number_of_shares() {
        let key = PrivateKey::random();
        let shares = split_private_key(&key, 5, 10).unwrap();
        assert_eq!(shares.points.len(), 10);
        // Recover from the last five shares only.
        let subset = KeyShares::new(shares.points[5..10].to_vec(), 5, shares.integrity.clone());
        let recovered = subset.recover_private_key().unwrap();
        assert_eq!(key.to_bytes(), recovered.to_bytes());
    }

    #[test]
    fn test_compute_integrity() {
        let key = PrivateKey::from_hex(
            "0000000000000000000000000000000000000000000000000000000000000001",
        )
        .unwrap();
        let integrity = compute_integrity(&key);
        assert_eq!(integrity.len(), 4);
        let integrity2 = compute_integrity(&key);
        assert_eq!(integrity, integrity2);
    }

    #[test]
    fn test_evaluate_polynomial() {
        let p = BigNumber::secp256k1_prime();
        // f(x) = 5 + 3x + 2x^2
        let coefficients = vec![
            BigNumber::from_u64(5),
            BigNumber::from_u64(3),
            BigNumber::from_u64(2),
        ];
        assert_eq!(
            evaluate_polynomial(&coefficients, &BigNumber::zero(), &p),
            BigNumber::from_u64(5)
        );
        assert_eq!(
            evaluate_polynomial(&coefficients, &BigNumber::from_u64(1), &p),
            BigNumber::from_u64(10)
        );
        assert_eq!(
            evaluate_polynomial(&coefficients, &BigNumber::from_u64(2), &p),
            BigNumber::from_u64(19)
        );
    }

    #[test]
    fn test_decode_share() {
        let key = PrivateKey::random();
        let shares = split_private_key(&key, 2, 3).unwrap();
        let backup = shares.to_backup_format();
        let (point, threshold, integrity) = decode_share(&backup[0]).unwrap();
        assert_eq!(threshold, 2);
        assert_eq!(integrity, shares.integrity);
        assert_eq!(point.x, shares.points[0].x);
        assert_eq!(point.y, shares.points[0].y);
    }

    #[test]
    fn test_decode_share_invalid_format() {
        assert!(matches!(decode_share("a.b.c"), Err(Error::CryptoError(_))));
        assert!(matches!(
            decode_share("a.b.c.d.e"),
            Err(Error::CryptoError(_))
        ));
        assert!(matches!(
            decode_share("2.abc.notanumber.XXXX"),
            Err(Error::CryptoError(_))
        ));
    }

    #[test]
    fn test_empty_shares() {
        let result = KeyShares::from_backup_format(&[]);
        assert!(matches!(result, Err(Error::CryptoError(_))));
    }

    #[test]
    fn test_threshold_greater_than_total_shares() {
        let key = PrivateKey::random();
        let result = split_private_key(&key, 50, 5);
        assert!(matches!(result, Err(Error::CryptoError(_))));
        assert!(
            result.unwrap_err().to_string().contains("must be at least"),
            "Expected error about total shares being less than threshold"
        );
    }

    #[test]
    fn test_total_shares_less_than_2() {
        let key = PrivateKey::random();
        let result = split_private_key(&key, 2, 1);
        assert!(matches!(result, Err(Error::CryptoError(_))));
        let result = split_private_key(&key, 1, 1);
        assert!(matches!(result, Err(Error::CryptoError(_))));
    }

    #[test]
    fn test_duplicate_shares_in_recovery() {
        let key = PrivateKey::random();
        let shares = split_private_key(&key, 3, 5).unwrap();
        let backup = shares.to_backup_format();
        // Share 1 appears twice, so only two distinct x-coordinates are present.
        let recovery = KeyShares::from_backup_format(&[
            backup[0].clone(),
            backup[1].clone(),
            backup[1].clone(),
        ])
        .unwrap();
        let result = recovery.recover_private_key();
        assert!(
            matches!(result, Err(Error::CryptoError(_))),
            "Expected CryptoError when using duplicate shares for recovery, got {:?}",
            result
        );
    }

    #[test]
    fn test_fewer_points_than_threshold() {
        let key = PrivateKey::random();
        let shares = split_private_key(&key, 3, 5).unwrap();
        let subset = KeyShares::new(shares.points[..2].to_vec(), 3, shares.integrity.clone());
        let result = subset.recover_private_key();
        assert!(matches!(result, Err(Error::CryptoError(_))));
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Insufficient shares"),
            "Expected 'Insufficient shares' error message"
        );
    }

    #[test]
    fn test_consistency_across_multiple_splits() {
        let key = PrivateKey::random();
        let shares1 = split_private_key(&key, 3, 5).unwrap();
        let shares2 = split_private_key(&key, 3, 5).unwrap();
        assert_ne!(
            shares1.points[0].y, shares2.points[0].y,
            "Two splits of the same key should produce different shares due to randomness"
        );
        let recovered1 = KeyShares::new(shares1.points[..3].to_vec(), 3, shares1.integrity.clone())
            .recover_private_key()
            .unwrap();
        let recovered2 = KeyShares::new(shares2.points[..3].to_vec(), 3, shares2.integrity.clone())
            .recover_private_key()
            .unwrap();
        assert_eq!(key.to_bytes(), recovered1.to_bytes());
        assert_eq!(key.to_bytes(), recovered2.to_bytes());
        assert_eq!(shares1.integrity, shares2.integrity);
    }

    #[test]
    fn test_different_recovery_subsets() {
        let key = PrivateKey::random();
        let shares = split_private_key(&key, 3, 5).unwrap();
        for subset_indices in &THREE_OF_FIVE {
            let subset_points: Vec<_> = subset_indices
                .iter()
                .map(|&i| shares.points[i].clone())
                .collect();
            let subset = KeyShares::new(subset_points, 3, shares.integrity.clone());
            let recovered = subset.recover_private_key().unwrap();
            assert_eq!(
                key.to_bytes(),
                recovered.to_bytes(),
                "Recovery failed for subset {:?}",
                subset_indices
            );
        }
    }

    #[test]
    fn test_single_share_threshold() {
        let key = PrivateKey::random();
        let result = split_private_key(&key, 1, 5);
        assert!(matches!(result, Err(Error::CryptoError(_))));
        assert!(
            result.unwrap_err().to_string().contains("at least 2"),
            "Expected error about threshold being at least 2"
        );
    }

    #[test]
    fn test_max_shares() {
        let key = PrivateKey::random();
        let shares = split_private_key(&key, 3, 255).unwrap();
        assert_eq!(shares.points.len(), 255);

        let subset = KeyShares::new(shares.points[..3].to_vec(), 3, shares.integrity.clone());
        let recovered = subset.recover_private_key().unwrap();
        assert_eq!(key.to_bytes(), recovered.to_bytes());

        let subset = KeyShares::new(
            shares.points[252..255].to_vec(),
            3,
            shares.integrity.clone(),
        );
        let recovered = subset.recover_private_key().unwrap();
        assert_eq!(key.to_bytes(), recovered.to_bytes());

        let subset = KeyShares::new(
            vec![
                shares.points[0].clone(),
                shares.points[127].clone(),
                shares.points[254].clone(),
            ],
            3,
            shares.integrity.clone(),
        );
        let recovered = subset.recover_private_key().unwrap();
        assert_eq!(key.to_bytes(), recovered.to_bytes());
    }

    #[test]
    fn test_threshold_exceeds_255() {
        let key = PrivateKey::random();
        let result = split_private_key(&key, 256, 300);
        assert!(matches!(result, Err(Error::CryptoError(_))));
        assert!(
            result.unwrap_err().to_string().contains("255"),
            "Expected error about threshold exceeding 255"
        );
    }

    #[test]
    fn test_different_thresholds_and_shares() {
        let test_cases = vec![(2, 3), (2, 5), (3, 5), (4, 7), (5, 10), (10, 10)];
        for (threshold, total) in test_cases {
            let key = PrivateKey::random();
            let shares = split_private_key(&key, threshold, total).unwrap();
            assert_eq!(shares.points.len(), total);
            let subset = KeyShares::new(
                shares.points[..threshold].to_vec(),
                threshold,
                shares.integrity.clone(),
            );
            let recovered = subset.recover_private_key().unwrap();
            assert_eq!(
                key.to_bytes(),
                recovered.to_bytes(),
                "Failed for threshold={}, total={}",
                threshold,
                total
            );
        }
    }

    #[test]
    fn test_recovery_with_more_shares_than_threshold() {
        let key = PrivateKey::random();
        let shares = split_private_key(&key, 3, 5).unwrap();
        let recovered = shares.recover_private_key().unwrap();
        assert_eq!(key.to_bytes(), recovered.to_bytes());
        let subset = KeyShares::new(shares.points[..4].to_vec(), 3, shares.integrity.clone());
        let recovered = subset.recover_private_key().unwrap();
        assert_eq!(key.to_bytes(), recovered.to_bytes());
    }

    #[test]
    fn test_recovery_with_wrong_shares_fails_integrity() {
        let key1 = PrivateKey::random();
        let key2 = PrivateKey::random();
        let shares1 = split_private_key(&key1, 2, 3).unwrap();
        let shares2 = split_private_key(&key2, 2, 3).unwrap();
        // One share from each key: interpolation yields an unrelated value.
        let mixed = KeyShares::new(
            vec![shares1.points[0].clone(), shares2.points[1].clone()],
            2,
            shares1.integrity.clone(),
        );
        let result = mixed.recover_private_key();
        assert!(
            matches!(result, Err(Error::CryptoError(_))),
            "Expected CryptoError when mixing shares from different keys, got {:?}",
            result
        );
    }

    #[test]
    fn test_multiple_recovery_iterations() {
        for _ in 0..10 {
            let key = PrivateKey::random();
            let shares = split_private_key(&key, 3, 5).unwrap();
            let subset = KeyShares::new(shares.points[..3].to_vec(), 3, shares.integrity.clone());
            let recovered = subset.recover_private_key().unwrap();
            assert_eq!(key.to_bytes(), recovered.to_bytes());
        }
    }

    #[test]
    fn test_backup_recovery_full_roundtrip() {
        for _ in 0..3 {
            let key = PrivateKey::random();
            let shares = split_private_key(&key, 3, 5).unwrap();
            let backup = shares.to_backup_format();
            assert_eq!(backup.len(), 5);
            let recovered_shares = KeyShares::from_backup_format(&backup[..3]).unwrap();
            let recovered_key = recovered_shares.recover_private_key().unwrap();
            assert_eq!(key.to_bytes(), recovered_key.to_bytes());
        }
    }

    #[test]
    fn test_zero_threshold() {
        let key = PrivateKey::random();
        let result = split_private_key(&key, 0, 5);
        assert!(matches!(result, Err(Error::CryptoError(_))));
    }

    #[test]
    fn test_zero_total_shares() {
        let key = PrivateKey::random();
        let result = split_private_key(&key, 2, 0);
        assert!(matches!(result, Err(Error::CryptoError(_))));
    }
}