use crate::encryption::{decrypt, encrypt, generate_nonce};
use crate::hash::hash;
use crate::kdf::hkdf_extract_expand;
use crate::shamir::{Share, reconstruct, split};
use crate::signing::KeyPair;
use rand::Rng as _;
use serde::{Deserialize, Serialize};
use std::time::{SystemTime, UNIX_EPOCH};
/// Errors produced by the backup and recovery functions in this module.
#[derive(Debug)]
pub enum BackupError {
    /// A [`BackupConfig`] has an unusable threshold / share count.
    InvalidThreshold(String),
    /// Recovery was attempted with no shares or fewer than the threshold.
    InsufficientShares(String),
    /// A share failed a consistency check or could not be turned back into a raw share.
    InvalidShare(String),
    /// Splitting, reconstruction, encryption or key construction failed.
    CryptoError(String),
    /// Encoding or decoding through the crate codec failed.
    SerializationError(String),
    /// Decrypting an encrypted backup failed (any decrypt error maps here).
    InvalidPassword,
    /// NOTE(review): not constructed by the code visible in this module.
    VersionMismatch(String),
}

impl std::fmt::Display for BackupError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Every message-carrying variant renders as "<kind>: <detail>".
        let (kind, detail) = match self {
            Self::InvalidThreshold(detail) => ("Invalid threshold", detail),
            Self::InsufficientShares(detail) => ("Insufficient shares", detail),
            Self::InvalidShare(detail) => ("Invalid share", detail),
            Self::CryptoError(detail) => ("Crypto error", detail),
            Self::SerializationError(detail) => ("Serialization error", detail),
            Self::VersionMismatch(detail) => ("Version mismatch", detail),
            // The only variant without an attached detail message.
            Self::InvalidPassword => return f.write_str("Invalid password"),
        };
        write!(f, "{kind}: {detail}")
    }
}

impl std::error::Error for BackupError {}

/// Shorthand result type used throughout the backup module.
pub type BackupResult<T> = Result<T, BackupError>;
/// Kind of secret a backup holds, recorded in `BackupConfig::key_type`.
///
/// This is metadata only: none of the backup/recovery functions in this
/// module inspect it. Keep the variant order stable; it may be part of the
/// serialized form depending on `crate::codec` (NOTE(review): confirm).
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub enum KeyType {
    /// A signing key.
    SigningKey,
    /// An encryption key.
    EncryptionKey,
    /// Arbitrary secret bytes; the default chosen by `BackupConfig::new`.
    GenericSecret,
}
/// Parameters and metadata for a backup.
///
/// A clone of the config is stored in every `BackupShare` and in each
/// `EncryptedBackup`. NOTE(review): if `crate::codec` is a positional format,
/// field order is part of the on-disk layout — confirm before reordering.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct BackupConfig {
    /// Minimum number of shares needed to recover the secret;
    /// `validate` requires `1 <= threshold <= total_shares`.
    pub threshold: usize,
    /// Number of shares to produce; `validate` caps it at 255.
    pub total_shares: usize,
    /// Optional free-form label.
    pub label: Option<String>,
    /// Optional free-form description.
    pub description: Option<String>,
    /// What kind of secret the backup holds (metadata only).
    pub key_type: KeyType,
    /// Format version; `BackupConfig::new` sets it to 1.
    /// NOTE(review): not compared by the recovery functions visible here.
    pub version: u32,
}
impl BackupConfig {
    /// Creates a `threshold`-of-`total_shares` configuration.
    ///
    /// Label and description start unset, `key_type` is
    /// [`KeyType::GenericSecret`] and `version` is 1. Nothing is checked here;
    /// call [`BackupConfig::validate`] for that.
    pub fn new(threshold: usize, total_shares: usize) -> Self {
        Self {
            version: 1,
            key_type: KeyType::GenericSecret,
            description: None,
            label: None,
            total_shares,
            threshold,
        }
    }

    /// Returns the config with `label` set.
    pub fn with_label(self, label: &str) -> Self {
        Self {
            label: Some(label.to_owned()),
            ..self
        }
    }

    /// Returns the config with `description` set.
    pub fn with_description(self, description: &str) -> Self {
        Self {
            description: Some(description.to_owned()),
            ..self
        }
    }

    /// Returns the config with `key_type` replaced.
    pub fn with_key_type(self, key_type: KeyType) -> Self {
        Self { key_type, ..self }
    }

    /// Returns the config with `version` replaced.
    pub fn with_version(self, version: u32) -> Self {
        Self { version, ..self }
    }

    /// Checks that the share parameters are usable.
    ///
    /// Rules are checked in this order and the first failure is reported:
    /// 1. `threshold` must be at least 1;
    /// 2. `threshold` must not exceed `total_shares`;
    /// 3. `total_shares` must not exceed 255 (share indices are stored as `u8`).
    ///
    /// # Errors
    ///
    /// [`BackupError::InvalidThreshold`] describing the first violated rule.
    pub fn validate(&self) -> BackupResult<()> {
        let (threshold, total) = (self.threshold, self.total_shares);
        if threshold == 0 {
            Err(BackupError::InvalidThreshold(
                "Threshold must be at least 1".to_string(),
            ))
        } else if threshold > total {
            Err(BackupError::InvalidThreshold(format!(
                "Threshold ({}) cannot exceed total shares ({})",
                threshold, total
            )))
        } else if total > 255 {
            Err(BackupError::InvalidThreshold(
                "Total shares cannot exceed 255".to_string(),
            ))
        } else {
            Ok(())
        }
    }
}
/// One Shamir share wrapped with backup metadata and an integrity checksum.
///
/// NOTE(review): if `crate::codec` is a positional format, field order is part
/// of the serialized layout — confirm before reordering.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct BackupShare {
    /// Share index; `to_share` passes it to `Share::new` as the share's index,
    /// so it must match the index the share was split with.
    pub index: u8,
    /// Raw share payload copied from the Shamir `Share`.
    pub share_data: Vec<u8>,
    /// Copy of the config used when the share was created.
    pub config: BackupConfig,
    /// Creation time in seconds since the Unix epoch.
    pub created_at: u64,
    /// `hash` over the index, `share_data` and `created_at`. The hash is
    /// unkeyed, so it detects accidental corruption but not deliberate edits
    /// (anyone can recompute it); `config` is not covered.
    pub checksum: [u8; 32],
}
impl BackupShare {
    /// Wraps a raw Shamir [`Share`] together with backup metadata and a checksum.
    ///
    /// `index` becomes the share index handed back to reconstruction by
    /// [`BackupShare::to_share`], so it must equal `share.index`.
    fn new(index: u8, share: Share, config: BackupConfig) -> Self {
        debug_assert_eq!(
            index, share.index,
            "backup share index must match the Shamir share index"
        );
        // `created_at` is informational; a clock set before the Unix epoch
        // records 0 instead of panicking.
        let created_at = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map(|elapsed| elapsed.as_secs())
            .unwrap_or_default();
        let share_data = share.data.clone();
        // Use the same helper as `verify_integrity` so the two can never
        // disagree on the checksummed layout.
        let checksum = Self::compute_checksum(index, &share_data, created_at);
        Self {
            index,
            share_data,
            config,
            created_at,
            checksum,
        }
    }

    /// Hashes `index`, the share payload and the creation timestamp.
    ///
    /// The index byte is written twice: that is the layout `verify_integrity`
    /// has always checked, so existing checksums stay valid.
    fn compute_checksum(index: u8, share_data: &[u8], created_at: u64) -> [u8; 32] {
        let mut data = Vec::with_capacity(2 + share_data.len() + 8);
        data.push(index);
        data.push(index);
        data.extend_from_slice(share_data);
        data.extend_from_slice(&created_at.to_le_bytes());
        hash(&data)
    }

    /// Returns `true` if the stored checksum matches the share's contents.
    ///
    /// The checksum is an unkeyed hash, so this detects accidental corruption
    /// of `index`, `share_data` or `created_at`, not deliberate tampering.
    /// `config` is not covered.
    pub fn verify_integrity(&self) -> bool {
        Self::compute_checksum(self.index, &self.share_data, self.created_at) == self.checksum
    }

    /// Converts back into a raw Shamir [`Share`] using `self.index` as its index.
    ///
    /// # Errors
    ///
    /// [`BackupError::InvalidShare`] if `Share::new` rejects the data.
    fn to_share(&self) -> BackupResult<Share> {
        Share::new(self.index, self.share_data.clone())
            .map_err(|e| BackupError::InvalidShare(e.to_string()))
    }

    /// Serializes this share with the crate codec.
    ///
    /// # Errors
    ///
    /// [`BackupError::SerializationError`] if encoding fails.
    pub fn to_bytes(&self) -> BackupResult<Vec<u8>> {
        crate::codec::encode(self).map_err(|e| BackupError::SerializationError(e.to_string()))
    }

    /// Deserializes a share produced by [`BackupShare::to_bytes`].
    ///
    /// The checksum is not verified here; call
    /// [`BackupShare::verify_integrity`] on the result.
    ///
    /// # Errors
    ///
    /// [`BackupError::SerializationError`] if decoding fails.
    pub fn from_bytes(bytes: &[u8]) -> BackupResult<Self> {
        crate::codec::decode(bytes).map_err(|e| BackupError::SerializationError(e.to_string()))
    }
}
/// A secret encrypted under a password-derived key, plus the parameters
/// needed to decrypt it.
///
/// NOTE(review): `config` and `created_at` are stored alongside the ciphertext
/// but are not passed to `encrypt`, so they are presumably not authenticated.
/// Also, if `crate::codec` is positional, field order is part of the layout.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct EncryptedBackup {
    /// Output of `encrypt` over the secret bytes.
    pub ciphertext: Vec<u8>,
    /// Nonce passed to `encrypt` / `decrypt`.
    pub nonce: [u8; 12],
    /// Random salt fed to `hkdf_extract_expand` together with the password.
    pub salt: [u8; 32],
    /// Copy of the config supplied at backup time.
    pub config: BackupConfig,
    /// Creation time in seconds since the Unix epoch.
    pub created_at: u64,
}
impl EncryptedBackup {
    /// Serializes this backup with the crate codec.
    ///
    /// # Errors
    ///
    /// [`BackupError::SerializationError`] if encoding fails.
    pub fn to_bytes(&self) -> BackupResult<Vec<u8>> {
        match crate::codec::encode(self) {
            Ok(encoded) => Ok(encoded),
            Err(err) => Err(BackupError::SerializationError(err.to_string())),
        }
    }

    /// Deserializes a backup produced by [`EncryptedBackup::to_bytes`].
    ///
    /// # Errors
    ///
    /// [`BackupError::SerializationError`] if decoding fails.
    pub fn from_bytes(bytes: &[u8]) -> BackupResult<Self> {
        match crate::codec::decode(bytes) {
            Ok(decoded) => Ok(decoded),
            Err(err) => Err(BackupError::SerializationError(err.to_string())),
        }
    }
}
/// Splits the secret key of `keypair` into Shamir shares.
///
/// `config.threshold` of the `config.total_shares` shares are required for
/// [`recover_key_shamir`]. Shares are numbered from 1 in the order `split`
/// returns them, and each carries a clone of `config`.
///
/// # Errors
///
/// - [`BackupError::InvalidThreshold`] if `config` fails validation.
/// - [`BackupError::CryptoError`] if `split` fails.
pub fn backup_key_shamir(
    keypair: &KeyPair,
    config: &BackupConfig,
) -> BackupResult<Vec<BackupShare>> {
    config.validate()?;
    let secret = keypair.secret_key();
    let raw_shares = match split(&secret, config.threshold, config.total_shares) {
        Ok(raw_shares) => raw_shares,
        Err(err) => return Err(BackupError::CryptoError(err.to_string())),
    };
    let mut wrapped = Vec::with_capacity(config.total_shares);
    for (position, share) in raw_shares.into_iter().enumerate() {
        // 1-based numbering; `validate` caps `total_shares` at 255, so this
        // fits in a u8 assuming `split` returns `total_shares` shares.
        wrapped.push(BackupShare::new((position + 1) as u8, share, config.clone()));
    }
    Ok(wrapped)
}
pub fn recover_key_shamir(shares: &[BackupShare]) -> BackupResult<KeyPair> {
if shares.is_empty() {
return Err(BackupError::InsufficientShares(
"No shares provided".to_string(),
));
}
let config = &shares[0].config;
if shares.len() < config.threshold {
return Err(BackupError::InsufficientShares(format!(
"Need {} shares but only {} provided",
config.threshold,
shares.len()
)));
}
for share in shares {
if !share.verify_integrity() {
return Err(BackupError::InvalidShare(format!(
"Share {} failed integrity check",
share.index
)));
}
if share.config.threshold != config.threshold {
return Err(BackupError::InvalidShare(
"Incompatible share thresholds".to_string(),
));
}
}
let raw_shares: Vec<Share> = shares
.iter()
.map(|bs| bs.to_share())
.collect::<Result<Vec<_>, _>>()?;
let secret = reconstruct(&raw_shares).map_err(|e| BackupError::CryptoError(e.to_string()))?;
if secret.len() != 32 {
return Err(BackupError::CryptoError(
"Invalid secret length".to_string(),
));
}
let mut secret_array = [0u8; 32];
secret_array.copy_from_slice(&secret);
KeyPair::from_secret_key(&secret_array).map_err(|e| BackupError::CryptoError(e.to_string()))
}
/// Splits arbitrary `secret` bytes into Shamir shares.
///
/// `config.threshold` of the `config.total_shares` shares are required for
/// [`recover_secret_shamir`]. Shares are numbered from 1 in the order `split`
/// returns them, and each carries a clone of `config`.
///
/// # Errors
///
/// - [`BackupError::InvalidThreshold`] if `config` fails validation.
/// - [`BackupError::CryptoError`] if `split` fails.
pub fn backup_secret_shamir(
    secret: &[u8],
    config: &BackupConfig,
) -> BackupResult<Vec<BackupShare>> {
    config.validate()?;
    let raw_shares = match split(secret, config.threshold, config.total_shares) {
        Ok(raw_shares) => raw_shares,
        Err(err) => return Err(BackupError::CryptoError(err.to_string())),
    };
    let mut wrapped = Vec::with_capacity(config.total_shares);
    for (position, share) in raw_shares.into_iter().enumerate() {
        // 1-based numbering; `validate` caps `total_shares` at 255, so this
        // fits in a u8 assuming `split` returns `total_shares` shares.
        wrapped.push(BackupShare::new((position + 1) as u8, share, config.clone()));
    }
    Ok(wrapped)
}
pub fn recover_secret_shamir(shares: &[BackupShare]) -> BackupResult<Vec<u8>> {
if shares.is_empty() {
return Err(BackupError::InsufficientShares(
"No shares provided".to_string(),
));
}
let config = &shares[0].config;
if shares.len() < config.threshold {
return Err(BackupError::InsufficientShares(format!(
"Need {} shares but only {} provided",
config.threshold,
shares.len()
)));
}
for share in shares {
if !share.verify_integrity() {
return Err(BackupError::InvalidShare(format!(
"Share {} failed integrity check",
share.index
)));
}
}
let raw_shares: Vec<Share> = shares
.iter()
.map(|bs| bs.to_share())
.collect::<Result<Vec<_>, _>>()?;
reconstruct(&raw_shares).map_err(|e| BackupError::CryptoError(e.to_string()))
}
/// Encrypts the secret key of `keypair` under a key derived from `password`.
///
/// Each call draws a fresh random 32-byte salt and a fresh nonce from
/// `generate_nonce`, so two backups of the same key differ even under the
/// same password. `config` is validated and a clone is stored in the result.
///
/// NOTE(review): the encryption key comes from `hkdf_extract_expand`
/// (presumably HKDF, per its name). HKDF applies no work factor, so an
/// attacker holding the backup can test password guesses cheaply. Switching
/// to a password-hashing KDF (e.g. Argon2id) would make existing backups
/// unrecoverable without a format version bump and needs a new dependency,
/// so it is not changed here.
///
/// # Errors
///
/// - [`BackupError::InvalidThreshold`] if `config` fails [`BackupConfig::validate`].
/// - [`BackupError::CryptoError`] if encryption fails or the generated nonce is
///   not exactly 12 bytes.
pub fn backup_key_encrypted(
    keypair: &KeyPair,
    password: &str,
    config: &BackupConfig,
) -> BackupResult<EncryptedBackup> {
    config.validate()?;
    let mut salt = [0u8; 32];
    rand::rng().fill_bytes(&mut salt);
    let key_bytes = hkdf_extract_expand(password.as_bytes(), &salt, b"chie-backup-encryption-v1");
    let secret = keypair.secret_key();
    let nonce = generate_nonce();
    let ciphertext = encrypt(&secret, &key_bytes, &nonce)
        .map_err(|e| BackupError::CryptoError(e.to_string()))?;
    // A wrong-sized nonce is a crypto-layer bug; report it instead of panicking.
    let nonce_bytes: [u8; 12] = nonce
        .as_slice()
        .try_into()
        .map_err(|_| BackupError::CryptoError("Nonce must be 12 bytes".to_string()))?;
    // `created_at` is informational; a clock set before the Unix epoch
    // records 0 instead of panicking.
    let created_at = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|elapsed| elapsed.as_secs())
        .unwrap_or_default();
    Ok(EncryptedBackup {
        ciphertext,
        nonce: nonce_bytes,
        salt,
        config: config.clone(),
        created_at,
    })
}
pub fn recover_key_encrypted(backup: &EncryptedBackup, password: &str) -> BackupResult<KeyPair> {
let key_bytes = hkdf_extract_expand(
password.as_bytes(),
&backup.salt,
b"chie-backup-encryption-v1",
);
let nonce = &backup.nonce;
let secret =
decrypt(&backup.ciphertext, &key_bytes, nonce).map_err(|_| BackupError::InvalidPassword)?;
if secret.len() != 32 {
return Err(BackupError::CryptoError(
"Invalid secret length".to_string(),
));
}
let mut secret_array = [0u8; 32];
secret_array.copy_from_slice(&secret);
KeyPair::from_secret_key(&secret_array).map_err(|e| BackupError::CryptoError(e.to_string()))
}
/// Encrypts arbitrary `secret` bytes under a key derived from `password`.
///
/// Each call draws a fresh random 32-byte salt and a fresh nonce from
/// `generate_nonce`, so two backups of the same secret differ even under the
/// same password. `config` is validated and a clone is stored in the result.
///
/// NOTE(review): the encryption key comes from `hkdf_extract_expand`
/// (presumably HKDF, per its name). HKDF applies no work factor, so an
/// attacker holding the backup can test password guesses cheaply. Switching
/// to a password-hashing KDF (e.g. Argon2id) would make existing backups
/// unrecoverable without a format version bump and needs a new dependency,
/// so it is not changed here.
///
/// # Errors
///
/// - [`BackupError::InvalidThreshold`] if `config` fails [`BackupConfig::validate`].
/// - [`BackupError::CryptoError`] if encryption fails or the generated nonce is
///   not exactly 12 bytes.
pub fn backup_secret_encrypted(
    secret: &[u8],
    password: &str,
    config: &BackupConfig,
) -> BackupResult<EncryptedBackup> {
    config.validate()?;
    let mut salt = [0u8; 32];
    rand::rng().fill_bytes(&mut salt);
    let key_bytes = hkdf_extract_expand(password.as_bytes(), &salt, b"chie-backup-encryption-v1");
    let nonce = generate_nonce();
    let ciphertext =
        encrypt(secret, &key_bytes, &nonce).map_err(|e| BackupError::CryptoError(e.to_string()))?;
    // A wrong-sized nonce is a crypto-layer bug; report it instead of panicking.
    let nonce_bytes: [u8; 12] = nonce
        .as_slice()
        .try_into()
        .map_err(|_| BackupError::CryptoError("Nonce must be 12 bytes".to_string()))?;
    // `created_at` is informational; a clock set before the Unix epoch
    // records 0 instead of panicking.
    let created_at = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|elapsed| elapsed.as_secs())
        .unwrap_or_default();
    Ok(EncryptedBackup {
        ciphertext,
        nonce: nonce_bytes,
        salt,
        config: config.clone(),
        created_at,
    })
}
pub fn recover_secret_encrypted(backup: &EncryptedBackup, password: &str) -> BackupResult<Vec<u8>> {
let key_bytes = hkdf_extract_expand(
password.as_bytes(),
&backup.salt,
b"chie-backup-encryption-v1",
);
let nonce = &backup.nonce;
decrypt(&backup.ciphertext, &key_bytes, nonce).map_err(|_| BackupError::InvalidPassword)
}
#[cfg(test)]
mod tests {
    // Round-trip tests for the Shamir and password-encrypted backup paths,
    // plus config validation, builder setters and serialization.
    use super::*;

    // 3-of-5 key split: shares are numbered from 1 and verify, and two
    // different subsets of at least 3 shares both recover the same public key.
    #[test]
    fn test_shamir_backup_recovery() {
        let keypair = KeyPair::generate();
        let config = BackupConfig::new(3, 5).with_label("test-key");
        let shares = backup_key_shamir(&keypair, &config).unwrap();
        assert_eq!(shares.len(), 5);
        for (i, share) in shares.iter().enumerate() {
            assert_eq!(share.index, (i + 1) as u8);
            assert!(share.verify_integrity());
        }
        let recovered = recover_key_shamir(&shares[0..3]).unwrap();
        assert_eq!(keypair.public_key(), recovered.public_key());
        let recovered = recover_key_shamir(&shares[1..5]).unwrap();
        assert_eq!(keypair.public_key(), recovered.public_key());
    }

    // Two shares of a 3-of-5 split are below the threshold and must be rejected.
    #[test]
    fn test_shamir_insufficient_shares() {
        let keypair = KeyPair::generate();
        let config = BackupConfig::new(3, 5);
        let shares = backup_key_shamir(&keypair, &config).unwrap();
        let result = recover_key_shamir(&shares[0..2]);
        assert!(result.is_err());
    }

    // Password-encrypted key backup round-trips with the correct password.
    #[test]
    fn test_encrypted_backup_recovery() {
        let keypair = KeyPair::generate();
        let password = "secure_password_123";
        let config = BackupConfig::new(1, 1).with_key_type(KeyType::SigningKey);
        let backup = backup_key_encrypted(&keypair, password, &config).unwrap();
        let recovered = recover_key_encrypted(&backup, password).unwrap();
        assert_eq!(keypair.public_key(), recovered.public_key());
    }

    // Decrypting with a different password must fail.
    #[test]
    fn test_encrypted_backup_wrong_password() {
        let keypair = KeyPair::generate();
        let password = "correct_password";
        let wrong_password = "wrong_password";
        let config = BackupConfig::new(1, 1);
        let backup = backup_key_encrypted(&keypair, password, &config).unwrap();
        let result = recover_key_encrypted(&backup, wrong_password);
        assert!(result.is_err());
    }

    // A share survives a to_bytes/from_bytes round trip with its checksum intact.
    #[test]
    fn test_backup_share_serialization() {
        let keypair = KeyPair::generate();
        let config = BackupConfig::new(2, 3);
        let shares = backup_key_shamir(&keypair, &config).unwrap();
        let bytes = shares[0].to_bytes().unwrap();
        let recovered_share = BackupShare::from_bytes(&bytes).unwrap();
        assert_eq!(shares[0].index, recovered_share.index);
        assert!(recovered_share.verify_integrity());
    }

    // An encrypted backup still decrypts after a to_bytes/from_bytes round trip.
    #[test]
    fn test_encrypted_backup_serialization() {
        let keypair = KeyPair::generate();
        let password = "test_password";
        let config = BackupConfig::new(1, 1);
        let backup = backup_key_encrypted(&keypair, password, &config).unwrap();
        let bytes = backup.to_bytes().unwrap();
        let recovered_backup = EncryptedBackup::from_bytes(&bytes).unwrap();
        let recovered_key = recover_key_encrypted(&recovered_backup, password).unwrap();
        assert_eq!(keypair.public_key(), recovered_key.public_key());
    }

    // Arbitrary bytes: 2-of-4 split recovers from a minimal and a larger subset.
    #[test]
    fn test_generic_secret_shamir_backup() {
        let secret = b"my secret data that needs backup";
        let config = BackupConfig::new(2, 4).with_key_type(KeyType::GenericSecret);
        let shares = backup_secret_shamir(secret, &config).unwrap();
        assert_eq!(shares.len(), 4);
        let recovered = recover_secret_shamir(&shares[0..2]).unwrap();
        assert_eq!(secret.as_slice(), recovered.as_slice());
        let recovered = recover_secret_shamir(&shares[1..4]).unwrap();
        assert_eq!(secret.as_slice(), recovered.as_slice());
    }

    // Arbitrary bytes round-trip through the password-encrypted path.
    #[test]
    fn test_generic_secret_encrypted_backup() {
        let secret = b"confidential data";
        let password = "strong_password";
        let config = BackupConfig::new(1, 1);
        let backup = backup_secret_encrypted(secret, password, &config).unwrap();
        let recovered = recover_secret_encrypted(&backup, password).unwrap();
        assert_eq!(secret.as_slice(), recovered.as_slice());
    }

    // Each validation rule: zero threshold, threshold > total, total > 255.
    #[test]
    fn test_invalid_threshold_config() {
        let config = BackupConfig::new(0, 5);
        assert!(config.validate().is_err());
        let config = BackupConfig::new(6, 5);
        assert!(config.validate().is_err());
        let config = BackupConfig::new(128, 256);
        assert!(config.validate().is_err());
    }

    // Builder setters store exactly what they are given.
    #[test]
    fn test_backup_config_builder() {
        let config = BackupConfig::new(3, 5)
            .with_label("main-key")
            .with_description("Primary signing key")
            .with_key_type(KeyType::SigningKey)
            .with_version(2);
        assert_eq!(config.label, Some("main-key".to_string()));
        assert_eq!(config.description, Some("Primary signing key".to_string()));
        assert_eq!(config.key_type, KeyType::SigningKey);
        assert_eq!(config.version, 2);
    }

    // Freshly created shares verify; flipping a payload byte breaks the checksum.
    #[test]
    fn test_share_integrity_verification() {
        let keypair = KeyPair::generate();
        let config = BackupConfig::new(2, 3);
        let shares = backup_key_shamir(&keypair, &config).unwrap();
        for share in &shares {
            assert!(share.verify_integrity());
        }
        let mut corrupted = shares[0].clone();
        corrupted.share_data[0] ^= 0xFF;
        assert!(!corrupted.verify_integrity());
    }

    // Two backups of the same key differ in ciphertext and salt. The salt is
    // drawn at random per call, so it would differ even with equal passwords.
    #[test]
    fn test_different_passwords_different_ciphertexts() {
        let keypair = KeyPair::generate();
        let config = BackupConfig::new(1, 1);
        let backup1 = backup_key_encrypted(&keypair, "password1", &config).unwrap();
        let backup2 = backup_key_encrypted(&keypair, "password2", &config).unwrap();
        assert_ne!(backup1.ciphertext, backup2.ciphertext);
        assert_ne!(backup1.salt, backup2.salt);
    }

    // Both Shamir recovery entry points reject an empty share list.
    #[test]
    fn test_empty_shares_recovery() {
        let shares: Vec<BackupShare> = vec![];
        let result = recover_key_shamir(&shares);
        assert!(result.is_err());
        let result = recover_secret_shamir(&shares);
        assert!(result.is_err());
    }
}