use bsv::primitives::hash::pbkdf2_hmac_sha512;
use bsv::primitives::private_key::PrivateKey;
use bsv::primitives::random::random_bytes;
use bsv::primitives::symmetric_key::SymmetricKey;
use crate::WalletError;
use super::types::StateSnapshot;
/// PBKDF2 iteration count. Not referenced in this file; presumably callers pass it
/// as `iterations` to `derive_key_from_password` — confirm at call sites.
pub const PBKDF2_NUM_ROUNDS: u32 = 7_777;

/// Snapshot layout 1: `[version][snapshot key][ciphertext]` (no profile ID).
const SNAPSHOT_VERSION_1: u8 = 1;
/// Snapshot layout 2: `[version][snapshot key][active profile ID][ciphertext]`.
const SNAPSHOT_VERSION_2: u8 = 2;

/// Length in bytes of the random symmetric key embedded in every snapshot.
const SNAPSHOT_KEY_SIZE: usize = 32;
/// Length in bytes of an active profile identifier.
const PROFILE_ID_SIZE: usize = 16;
/// Length in bytes of the root key stored at the start of the decrypted payload.
const ROOT_KEY_SIZE: usize = 32;
/// Stretches `password` with PBKDF2-HMAC-SHA512 keyed by `salt`.
///
/// `iterations` is the PBKDF2 work factor. The derived key is always 64 bytes long.
pub fn derive_key_from_password(password: &[u8], salt: &[u8], iterations: u32) -> Vec<u8> {
    // Left untyped so the literal takes whatever integer type the PBKDF2 helper expects.
    let derived_len = 64;
    pbkdf2_hmac_sha512(password, salt, iterations, derived_len)
}
/// XORs two byte strings element-wise and returns the combined bytes.
///
/// # Panics
///
/// Panics if `key1` and `key2` have different lengths. The check runs in every
/// build profile. A debug-only assertion would let release builds silently
/// truncate the result to the shorter input, because `zip` stops at the shorter
/// side. That would quietly weaken any key material combined this way.
pub fn xor_keys(key1: &[u8], key2: &[u8]) -> Vec<u8> {
    assert_eq!(
        key1.len(),
        key2.len(),
        "xor_keys: arrays must have equal length ({} vs {})",
        key1.len(),
        key2.len()
    );
    key1.iter().zip(key2).map(|(a, b)| a ^ b).collect()
}
/// Derives the identity public key from the first 32 bytes of `root_key`.
///
/// Any bytes after the first 32 are ignored. The result is the public key's
/// `to_der_hex()` string.
///
/// # Errors
///
/// Returns `WalletError::Internal` in two cases:
/// - `root_key` is shorter than 32 bytes;
/// - `PrivateKey::from_bytes` rejects the leading 32 bytes.
pub fn derive_identity_key(root_key: &[u8]) -> Result<String, WalletError> {
    let secret = match root_key.get(..32) {
        Some(bytes) => bytes,
        None => {
            return Err(WalletError::Internal(format!(
                "Root key too short for identity derivation: {} bytes, need at least 32",
                root_key.len()
            )))
        }
    };
    let private_key = PrivateKey::from_bytes(secret).map_err(|err| {
        WalletError::Internal(format!("Failed to create PrivateKey from root key: {}", err))
    })?;
    Ok(private_key.to_public_key().to_der_hex())
}
/// Attempts to restore the root key from a parsed `StateSnapshot`.
///
/// This currently never succeeds. The error message tells the caller why:
/// - If the snapshot has a presentation key, restoration additionally needs a UMP
///   token lookup and the password or recovery key.
/// - Otherwise, the user must re-authenticate.
///
/// # Errors
///
/// Always returns `WalletError::Internal`.
pub fn restore_root_key_from_snapshot(snapshot: &StateSnapshot) -> Result<Vec<u8>, WalletError> {
    let reason = if snapshot.presentation_key.is_some() {
        "Root key restoration from snapshot requires UMP token lookup and password/recovery key"
    } else {
        "No presentation key in snapshot; re-authentication required"
    };
    Err(WalletError::Internal(reason.to_string()))
}
pub fn restore_root_key_from_snapshot_bytes(
snapshot_bytes: &[u8],
) -> Result<(Vec<u8>, Vec<u8>), WalletError> {
if snapshot_bytes.is_empty() {
return Err(WalletError::Internal(
"Snapshot bytes are empty".to_string(),
));
}
let version = snapshot_bytes[0];
match version {
SNAPSHOT_VERSION_1 => restore_v1(snapshot_bytes),
SNAPSHOT_VERSION_2 => restore_v2(snapshot_bytes),
_ => Err(WalletError::Internal(format!(
"Unsupported snapshot version: {}",
version
))),
}
}
/// Parses a version-1 snapshot: `[version][snapshot key][ciphertext]`.
///
/// Version 1 stores no profile ID, so the returned active profile ID is all zeros.
fn restore_v1(data: &[u8]) -> Result<(Vec<u8>, Vec<u8>), WalletError> {
    let min_len = 1 + SNAPSHOT_KEY_SIZE;
    // At least one ciphertext byte must follow the header.
    if data.len() <= min_len {
        return Err(WalletError::Internal(format!(
            "V1 snapshot too short: {} bytes, need more than {}",
            data.len(),
            min_len
        )));
    }
    // Skip the version byte. The length check above keeps `split_at` in bounds.
    let (snapshot_key_bytes, encrypted_payload) = data[1..].split_at(SNAPSHOT_KEY_SIZE);
    let root_key = decrypt_snapshot_root_key(snapshot_key_bytes, encrypted_payload)?;
    Ok((root_key, vec![0u8; PROFILE_ID_SIZE]))
}

/// Parses a version-2 snapshot:
/// `[version][snapshot key][active profile ID][ciphertext]`.
///
/// The active profile ID is stored unencrypted in the header and returned as-is.
fn restore_v2(data: &[u8]) -> Result<(Vec<u8>, Vec<u8>), WalletError> {
    let min_len = 1 + SNAPSHOT_KEY_SIZE + PROFILE_ID_SIZE;
    // At least one ciphertext byte must follow the header.
    if data.len() <= min_len {
        return Err(WalletError::Internal(format!(
            "V2 snapshot too short: {} bytes, need more than {}",
            data.len(),
            min_len
        )));
    }
    // Skip the version byte. The length check above keeps both splits in bounds.
    let (snapshot_key_bytes, rest) = data[1..].split_at(SNAPSHOT_KEY_SIZE);
    let (active_profile_id, encrypted_payload) = rest.split_at(PROFILE_ID_SIZE);
    let root_key = decrypt_snapshot_root_key(snapshot_key_bytes, encrypted_payload)?;
    Ok((root_key, active_profile_id.to_vec()))
}

/// Decrypts a snapshot payload and returns its leading `ROOT_KEY_SIZE` bytes.
///
/// Shared by every snapshot version. Bytes after the root key are not inspected
/// here; in payloads written by `save_snapshot` those bytes are the
/// varint-length-prefixed token.
fn decrypt_snapshot_root_key(
    snapshot_key_bytes: &[u8],
    encrypted_payload: &[u8],
) -> Result<Vec<u8>, WalletError> {
    let sym_key = SymmetricKey::from_bytes(snapshot_key_bytes)
        .map_err(|e| WalletError::Internal(format!("Invalid snapshot key: {}", e)))?;
    let decrypted = sym_key
        .decrypt(encrypted_payload)
        .map_err(|e| WalletError::Internal(format!("Snapshot decryption failed: {}", e)))?;
    if decrypted.len() < ROOT_KEY_SIZE {
        return Err(WalletError::Internal(format!(
            "Decrypted payload too short for root key: {} bytes, need at least {}",
            decrypted.len(),
            ROOT_KEY_SIZE
        )));
    }
    Ok(decrypted[..ROOT_KEY_SIZE].to_vec())
}
/// Encodes `value` as an unsigned LEB128-style varint.
///
/// Each byte carries 7 bits, least-significant group first. Every byte except the
/// last has its high bit set. Zero encodes as the single byte `0x00`.
///
/// NOTE(review): this is not Bitcoin's CompactSize encoding (0xFD/0xFE/0xFF
/// prefixes); the two diverge for values >= 128. If snapshots must be readable by
/// other BSV implementations, confirm which length-prefix format they expect.
fn encode_varint(mut value: usize) -> Vec<u8> {
    let mut out = Vec::new();
    while value >= 0x80 {
        // Low 7 bits, plus the continuation flag because more groups follow.
        out.push(((value as u8) & 0x7F) | 0x80);
        value >>= 7;
    }
    // Final group: fits in 7 bits, continuation flag clear.
    out.push(value as u8);
    out
}
/// Builds an encrypted version-2 snapshot of the root key and profile state.
///
/// Output layout: `[version=2][snapshot key][active profile ID][ciphertext]`.
/// The plaintext is `root_key || varint(token_bytes.len()) || token_bytes`, where
/// the varint comes from `encode_varint`. It is encrypted under a freshly generated
/// random snapshot key, so two snapshots of identical state differ.
///
/// NOTE(review): the snapshot key is written in clear next to the ciphertext.
/// Anyone holding the returned bytes can therefore recover the root key. Store the
/// snapshot with the same care as the root key itself.
///
/// # Errors
///
/// Returns `WalletError::Internal` in any of these cases:
/// - `root_key` is not `ROOT_KEY_SIZE` bytes;
/// - `active_profile_id` is not `PROFILE_ID_SIZE` bytes;
/// - key construction fails;
/// - encryption fails.
pub fn save_snapshot(
    root_key: &[u8],
    active_profile_id: &[u8],
    token_bytes: &[u8],
) -> Result<Vec<u8>, WalletError> {
    if root_key.len() != ROOT_KEY_SIZE {
        return Err(WalletError::Internal(format!(
            "Root key must be {} bytes, got {}",
            ROOT_KEY_SIZE,
            root_key.len()
        )));
    }
    if active_profile_id.len() != PROFILE_ID_SIZE {
        return Err(WalletError::Internal(format!(
            "Active profile ID must be {} bytes, got {}",
            PROFILE_ID_SIZE,
            active_profile_id.len()
        )));
    }

    // Plaintext: root key first, then the length-prefixed token.
    let token_len_prefix = encode_varint(token_bytes.len());
    let plaintext_parts: [&[u8]; 3] = [root_key, &token_len_prefix[..], token_bytes];
    let plaintext = plaintext_parts.concat();

    let key_material = random_bytes(SNAPSHOT_KEY_SIZE);
    let ciphertext = SymmetricKey::from_bytes(&key_material)
        .map_err(|e| WalletError::Internal(format!("Failed to create snapshot key: {}", e)))?
        .encrypt(&plaintext)
        .map_err(|e| WalletError::Internal(format!("Snapshot encryption failed: {}", e)))?;

    let version_tag = [SNAPSHOT_VERSION_2];
    let output_parts: [&[u8]; 4] = [
        &version_tag[..],
        &key_material[..],
        active_profile_id,
        &ciphertext[..],
    ];
    Ok(output_parts.concat())
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_xor_keys() {
        let a = vec![0xAA, 0xBB, 0xCC, 0xDD];
        let b = vec![0x11, 0x22, 0x33, 0x44];
        let result = xor_keys(&a, &b);
        assert_eq!(result, vec![0xBB, 0x99, 0xFF, 0x99]);
        let zeros = xor_keys(&a, &a);
        assert_eq!(zeros, vec![0x00, 0x00, 0x00, 0x00]);
        let z = vec![0x00, 0x00, 0x00, 0x00];
        assert_eq!(xor_keys(&a, &z), a);
    }

    #[test]
    fn test_derive_key_from_password() {
        let password = b"test-password";
        let salt = b"test-salt-value";
        let result = derive_key_from_password(password, salt, 100);
        assert_eq!(
            result.len(),
            64,
            "PBKDF2 output should be 64 bytes, got {}",
            result.len()
        );
        let result2 = derive_key_from_password(password, salt, 100);
        assert_eq!(result, result2, "PBKDF2 should be deterministic");
        let result3 = derive_key_from_password(password, b"different-salt", 100);
        assert_ne!(
            result, result3,
            "Different salt should produce different key"
        );
    }

    #[test]
    fn test_derive_identity_key_rejects_short_root_key() {
        let result = derive_identity_key(&[0x01; 31]);
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("too short"), "Error: {}", err_msg);
    }

    #[test]
    fn test_derive_identity_key_ignores_trailing_bytes() {
        // Only the first 32 bytes participate in identity derivation.
        let exact = derive_identity_key(&[0x01; 32]).expect("valid 32-byte key");
        let longer = derive_identity_key(&[0x01; 64]).expect("valid 64-byte key");
        assert_eq!(exact, longer);
    }

    #[test]
    fn test_save_restore_roundtrip_v2() {
        let root_key = vec![0xAB; 32];
        let profile_id = vec![0xCD; 16];
        let token_bytes = b"serialized-ump-token-data";
        let snapshot = save_snapshot(&root_key, &profile_id, token_bytes)
            .expect("save_snapshot should succeed");
        assert_eq!(snapshot[0], SNAPSHOT_VERSION_2);
        let (restored_root, restored_profile) =
            restore_root_key_from_snapshot_bytes(&snapshot).expect("restore should succeed");
        assert_eq!(restored_root, root_key, "root key should round-trip");
        assert_eq!(restored_profile, profile_id, "profile ID should round-trip");
    }

    #[test]
    fn test_save_snapshot_v2_header_layout() {
        let root_key = vec![0x5A; ROOT_KEY_SIZE];
        let profile_id = vec![0xC3; PROFILE_ID_SIZE];
        let snapshot = save_snapshot(&root_key, &profile_id, b"token").unwrap();
        // Header is [version][snapshot key][profile ID], followed by at least one
        // ciphertext byte.
        let header_len = 1 + SNAPSHOT_KEY_SIZE + PROFILE_ID_SIZE;
        assert!(snapshot.len() > header_len);
        assert_eq!(snapshot[0], SNAPSHOT_VERSION_2);
        assert_eq!(
            &snapshot[1 + SNAPSHOT_KEY_SIZE..header_len],
            &profile_id[..],
            "profile ID is stored in clear after the snapshot key"
        );
    }

    #[test]
    fn test_save_restore_roundtrip_empty_token() {
        let root_key = vec![0xFF; 32];
        let profile_id = vec![0x00; 16];
        let token_bytes = b"";
        let snapshot = save_snapshot(&root_key, &profile_id, token_bytes)
            .expect("save_snapshot should succeed");
        let (restored_root, restored_profile) =
            restore_root_key_from_snapshot_bytes(&snapshot).expect("restore should succeed");
        assert_eq!(restored_root, root_key);
        assert_eq!(restored_profile, profile_id);
    }

    #[test]
    fn test_save_restore_roundtrip_large_token() {
        let root_key = vec![0x42; 32];
        let profile_id = vec![0x13; 16];
        let token_bytes = vec![0x77; 1024];
        let snapshot = save_snapshot(&root_key, &profile_id, &token_bytes)
            .expect("save_snapshot should succeed");
        let (restored_root, restored_profile) =
            restore_root_key_from_snapshot_bytes(&snapshot).expect("restore should succeed");
        assert_eq!(restored_root, root_key);
        assert_eq!(restored_profile, profile_id);
    }

    #[test]
    fn test_v2_format_parsing() {
        let root_key = vec![0x11; 32];
        let profile_id = vec![0x22; 16];
        let mut payload = Vec::new();
        payload.extend_from_slice(&root_key);
        payload.push(0x00);
        let snapshot_key_bytes = random_bytes(32);
        let sym_key = SymmetricKey::from_bytes(&snapshot_key_bytes).unwrap();
        let encrypted = sym_key.encrypt(&payload).unwrap();
        let mut snapshot = Vec::new();
        snapshot.push(SNAPSHOT_VERSION_2);
        snapshot.extend_from_slice(&snapshot_key_bytes);
        snapshot.extend_from_slice(&profile_id);
        snapshot.extend_from_slice(&encrypted);
        let (restored_root, restored_profile) =
            restore_root_key_from_snapshot_bytes(&snapshot).expect("V2 parsing should succeed");
        assert_eq!(restored_root, root_key);
        assert_eq!(restored_profile, profile_id);
    }

    #[test]
    fn test_v1_format_parsing() {
        let root_key = vec![0x33; 32];
        let mut payload = Vec::new();
        payload.extend_from_slice(&root_key);
        payload.push(0x00);
        let snapshot_key_bytes = random_bytes(32);
        let sym_key = SymmetricKey::from_bytes(&snapshot_key_bytes).unwrap();
        let encrypted = sym_key.encrypt(&payload).unwrap();
        let mut snapshot = Vec::new();
        snapshot.push(SNAPSHOT_VERSION_1);
        snapshot.extend_from_slice(&snapshot_key_bytes);
        snapshot.extend_from_slice(&encrypted);
        let (restored_root, restored_profile) =
            restore_root_key_from_snapshot_bytes(&snapshot).expect("V1 parsing should succeed");
        assert_eq!(restored_root, root_key);
        assert_eq!(restored_profile, vec![0u8; 16]);
    }

    #[test]
    fn test_unsupported_version() {
        let data = vec![0xFF, 0x00, 0x00];
        let result = restore_root_key_from_snapshot_bytes(&data);
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(
            err_msg.contains("Unsupported snapshot version"),
            "Error: {}",
            err_msg
        );
    }

    #[test]
    fn test_empty_snapshot_bytes() {
        let result = restore_root_key_from_snapshot_bytes(&[]);
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("empty"), "Error: {}", err_msg);
    }

    #[test]
    fn test_truncated_v2_snapshot() {
        let data = vec![SNAPSHOT_VERSION_2; 10];
        let result = restore_root_key_from_snapshot_bytes(&data);
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("too short"), "Error: {}", err_msg);
    }

    #[test]
    fn test_v2_snapshot_with_header_only_is_too_short() {
        // Exactly the header length with no ciphertext must be rejected.
        let mut data = vec![0u8; 1 + SNAPSHOT_KEY_SIZE + PROFILE_ID_SIZE];
        data[0] = SNAPSHOT_VERSION_2;
        let result = restore_root_key_from_snapshot_bytes(&data);
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("too short"), "Error: {}", err_msg);
    }

    #[test]
    fn test_truncated_v1_snapshot() {
        let data = vec![SNAPSHOT_VERSION_1; 10];
        let result = restore_root_key_from_snapshot_bytes(&data);
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("too short"), "Error: {}", err_msg);
    }

    #[test]
    fn test_v1_snapshot_with_header_only_is_too_short() {
        // Exactly the header length with no ciphertext must be rejected.
        let mut data = vec![0u8; 1 + SNAPSHOT_KEY_SIZE];
        data[0] = SNAPSHOT_VERSION_1;
        let result = restore_root_key_from_snapshot_bytes(&data);
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("too short"), "Error: {}", err_msg);
    }

    #[test]
    fn test_save_snapshot_invalid_root_key_size() {
        let result = save_snapshot(&[0u8; 16], &[0u8; 16], &[]);
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Root key must be 32 bytes"));
    }

    #[test]
    fn test_save_snapshot_invalid_profile_id_size() {
        let result = save_snapshot(&[0u8; 32], &[0u8; 8], &[]);
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Active profile ID must be 16 bytes"));
    }

    #[test]
    fn test_encode_varint_small() {
        assert_eq!(encode_varint(0), vec![0x00]);
        assert_eq!(encode_varint(1), vec![0x01]);
        assert_eq!(encode_varint(127), vec![0x7F]);
    }

    #[test]
    fn test_encode_varint_multi_byte() {
        assert_eq!(encode_varint(128), vec![0x80, 0x01]);
        assert_eq!(encode_varint(300), vec![0xAC, 0x02]);
    }

    #[test]
    fn test_two_snapshots_differ() {
        let root_key = vec![0x42; 32];
        let profile_id = vec![0x00; 16];
        let s1 = save_snapshot(&root_key, &profile_id, &[]).unwrap();
        let s2 = save_snapshot(&root_key, &profile_id, &[]).unwrap();
        assert_ne!(s1, s2, "Two snapshots should differ due to random keys");
        let (r1, _) = restore_root_key_from_snapshot_bytes(&s1).unwrap();
        let (r2, _) = restore_root_key_from_snapshot_bytes(&s2).unwrap();
        assert_eq!(r1, root_key);
        assert_eq!(r2, root_key);
    }
}