use chacha20poly1305::{
aead::{Aead, KeyInit, Payload},
XChaCha20Poly1305, Key, XNonce,
};
use hkdf::Hkdf;
use sha2::{Digest, Sha256, Sha512};
use crate::{Error, Result};
// HKDF `info` labels used for domain separation. Every subkey in `DerivedKeys`
// (and the LSH seed) is expanded from the same seed/salt and differs only by
// its label, so changing any of these byte strings changes the derived key.
// The parity tests in this file compare the results against shared fixtures.
const AUTH_KEY_INFO: &[u8] = b"totalreclaw-auth-key-v1";
const ENCRYPTION_KEY_INFO: &[u8] = b"totalreclaw-encryption-key-v1";
// NOTE(review): the "openmemory-" prefix differs from the "totalreclaw-" labels
// above — presumably a legacy name kept for compatibility; confirm before renaming.
const DEDUP_KEY_INFO: &[u8] = b"openmemory-dedup-v1";
const LSH_SEED_INFO: &[u8] = b"openmemory-lsh-seed-v1";
// Size in bytes of the XChaCha20-Poly1305 (extended) nonce.
const NONCE_LENGTH: usize = 24;
// Size in bytes of the Poly1305 authentication tag.
const TAG_LENGTH: usize = 16;
/// Key material derived from a BIP-39 mnemonic seed.
///
/// `Debug` is implemented by hand (instead of derived) so that the secret keys
/// are never written verbatim into logs, panic messages or test failure output.
#[derive(Clone)]
pub struct DerivedKeys {
    /// HKDF-SHA256 subkey used for authentication.
    pub auth_key: [u8; 32],
    /// HKDF-SHA256 subkey used for XChaCha20-Poly1305 encryption.
    pub encryption_key: [u8; 32],
    /// HKDF-SHA256 subkey used for deduplication.
    pub dedup_key: [u8; 32],
    /// First 32 bytes of the BIP-39 seed; also used as the HKDF salt.
    pub salt: [u8; 32],
}

impl std::fmt::Debug for DerivedKeys {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // NOTE(review): `salt` is still shown, as before. It is half of the
        // BIP-39 seed — confirm it is acceptable to appear in logs.
        f.debug_struct("DerivedKeys")
            .field("auth_key", &"<redacted>")
            .field("encryption_key", &"<redacted>")
            .field("dedup_key", &"<redacted>")
            .field("salt", &self.salt)
            .finish()
    }
}
fn mnemonic_to_seed(mnemonic: &str) -> Result<[u8; 64]> {
let trimmed = mnemonic.trim();
bip39::Mnemonic::parse(trimmed).map_err(|e| {
Error::InvalidMnemonic(format!("invalid BIP-39 mnemonic: {}", e))
})?;
pbkdf2_seed(trimmed)
}
/// Like `mnemonic_to_seed`, but does not verify the BIP-39 checksum.
///
/// Only two things are checked:
/// - the trimmed phrase has exactly 12 or 24 whitespace-separated words;
/// - every word is in the English BIP-39 wordlist.
///
/// NOTE(review): the strict path delegates to `bip39::Mnemonic::parse`, which
/// may also accept 15/18/21-word phrases that this function rejects — confirm
/// that restriction is intended.
///
/// # Errors
/// `Error::InvalidMnemonic` on a wrong word count or an unknown word. The
/// message names the first offending word.
fn mnemonic_to_seed_lenient(mnemonic: &str) -> Result<[u8; 64]> {
    let phrase = mnemonic.trim();
    let tokens: Vec<&str> = phrase.split_whitespace().collect();
    let count = tokens.len();

    if !matches!(count, 12 | 24) {
        return Err(Error::InvalidMnemonic(format!(
            "expected 12 or 24 words, got {}",
            count
        )));
    }

    let wordlist = bip39::Language::English;
    let unknown = tokens
        .iter()
        .copied()
        .find(|word| wordlist.find_word(word).is_none());
    if let Some(bad) = unknown {
        return Err(Error::InvalidMnemonic(format!(
            "word '{}' not in BIP-39 English wordlist",
            bad
        )));
    }

    pbkdf2_seed(phrase)
}
/// Applies the BIP-39 seed function to an already-validated mnemonic.
///
/// Uses PBKDF2-HMAC-SHA512 with 2048 rounds and the salt `"mnemonic"`, i.e. an
/// empty passphrase. Always returns `Ok`.
fn pbkdf2_seed(mnemonic: &str) -> Result<[u8; 64]> {
    const BIP39_SALT: &[u8] = b"mnemonic";
    const BIP39_ROUNDS: u32 = 2048;

    let mut out = [0u8; 64];
    pbkdf2::pbkdf2_hmac::<Sha512>(mnemonic.as_bytes(), BIP39_SALT, BIP39_ROUNDS, &mut out);
    Ok(out)
}
/// Derives all subkeys from a fully validated BIP-39 mnemonic (checksum enforced).
///
/// # Errors
/// `Error::InvalidMnemonic` for an invalid phrase; `Error::Crypto` if HKDF fails.
pub fn derive_keys_from_mnemonic(mnemonic: &str) -> Result<DerivedKeys> {
    mnemonic_to_seed(mnemonic).and_then(|seed| derive_keys_from_seed(&seed))
}
/// Derives all subkeys from a mnemonic without checking its BIP-39 checksum.
///
/// The word count and wordlist membership are still checked. For a phrase that
/// also passes the strict check, the result equals `derive_keys_from_mnemonic`.
///
/// # Errors
/// `Error::InvalidMnemonic` for a bad word count or unknown word;
/// `Error::Crypto` if HKDF fails.
pub fn derive_keys_from_mnemonic_lenient(mnemonic: &str) -> Result<DerivedKeys> {
    mnemonic_to_seed_lenient(mnemonic).and_then(|seed| derive_keys_from_seed(&seed))
}
/// Expands a 64-byte BIP-39 seed into the `DerivedKeys` set.
///
/// Each subkey is HKDF-SHA256 with:
/// - IKM = the whole seed;
/// - salt = the seed's first 32 bytes;
/// - a per-key `info` label.
///
/// The keys are derived in the order auth, encryption, dedup.
fn derive_keys_from_seed(seed: &[u8; 64]) -> Result<DerivedKeys> {
    // The first half of the seed doubles as the HKDF salt and is returned as-is.
    let mut salt = [0u8; 32];
    salt.copy_from_slice(&seed[..32]);

    let derive = |info: &[u8]| hkdf_sha256(seed, &salt, info);

    Ok(DerivedKeys {
        auth_key: derive(AUTH_KEY_INFO)?,
        encryption_key: derive(ENCRYPTION_KEY_INFO)?,
        dedup_key: derive(DEDUP_KEY_INFO)?,
        salt,
    })
}
/// Public access to the raw 64-byte BIP-39 seed of a strictly validated mnemonic.
///
/// # Errors
/// `Error::InvalidMnemonic` if the phrase fails BIP-39 validation.
pub fn mnemonic_to_seed_bytes(mnemonic: &str) -> Result<[u8; 64]> {
    let seed = mnemonic_to_seed(mnemonic)?;
    Ok(seed)
}
/// Derives the 32-byte LSH seed.
///
/// Uses HKDF-SHA256 with:
/// - IKM = the mnemonic's BIP-39 seed;
/// - salt = the caller-supplied `salt`;
/// - info = `LSH_SEED_INFO`.
///
/// The BIP-39 seed is recomputed here, which runs PBKDF2 again.
///
/// NOTE(review): this uses strict validation, so a mnemonic accepted only by
/// `derive_keys_from_mnemonic_lenient` (bad checksum) is rejected here.
///
/// # Errors
/// `Error::InvalidMnemonic` for an invalid phrase; `Error::Crypto` if HKDF fails.
pub fn derive_lsh_seed(mnemonic: &str, salt: &[u8; 32]) -> Result<[u8; 32]> {
    mnemonic_to_seed(mnemonic).and_then(|seed| hkdf_sha256(&seed, salt, LSH_SEED_INFO))
}
/// Returns the lowercase hex encoding of `SHA-256(auth_key)`.
pub fn compute_auth_key_hash(auth_key: &[u8; 32]) -> String {
    hex::encode(Sha256::digest(auth_key))
}
/// One-shot HKDF-SHA256 (extract, then expand) producing 32 bytes of output.
///
/// # Errors
/// `Error::Crypto` if `expand` reports an invalid output length.
fn hkdf_sha256(ikm: &[u8], salt: &[u8], info: &[u8]) -> Result<[u8; 32]> {
    let mut output = [0u8; 32];
    Hkdf::<Sha256>::new(Some(salt), ikm)
        .expand(info, &mut output)
        .map_err(|e| Error::Crypto(format!("HKDF expand failed: {}", e)))?;
    Ok(output)
}
/// Encrypts `plaintext` with XChaCha20-Poly1305 under a freshly generated
/// random 24-byte nonce.
///
/// Returns base64(nonce || tag || ciphertext); see `encrypt_with_nonce`.
pub fn encrypt(plaintext: &str, encryption_key: &[u8; 32]) -> Result<String> {
    let nonce = rand::random::<[u8; NONCE_LENGTH]>();
    encrypt_with_nonce(plaintext, encryption_key, &nonce)
}
/// Encrypts `plaintext` with XChaCha20-Poly1305 using the caller's `nonce`
/// and empty associated data.
///
/// The output is standard base64 of `nonce (24) || tag (16) || ciphertext`.
/// The tag is placed before the ciphertext, which differs from the AEAD
/// crate's own `ciphertext || tag` order. Reusing a nonce with the same key
/// breaks confidentiality; `encrypt` generates a fresh one per call.
///
/// # Errors
/// `Error::Crypto` if the AEAD encryption fails.
pub fn encrypt_with_nonce(
    plaintext: &str,
    encryption_key: &[u8; 32],
    nonce: &[u8; NONCE_LENGTH],
) -> Result<String> {
    use base64::Engine;

    let cipher = XChaCha20Poly1305::new(Key::from_slice(encryption_key));
    let payload = Payload { msg: plaintext.as_bytes(), aad: b"" };
    let sealed = cipher
        .encrypt(XNonce::from_slice(nonce), payload)
        .map_err(|e| Error::Crypto(format!("XChaCha20-Poly1305 encrypt failed: {}", e)))?;

    // The AEAD yields `ciphertext || tag`; reorder into the wire layout.
    let (body, tag) = sealed.split_at(sealed.len() - TAG_LENGTH);
    let wire = [&nonce[..], tag, body].concat();

    Ok(base64::engine::general_purpose::STANDARD.encode(wire))
}
pub fn decrypt(encrypted_base64: &str, encryption_key: &[u8; 32]) -> Result<String> {
use base64::Engine;
let combined = base64::engine::general_purpose::STANDARD
.decode(encrypted_base64)
.map_err(|e| Error::Crypto(format!("base64 decode failed: {}", e)))?;
if combined.len() < NONCE_LENGTH + TAG_LENGTH {
return Err(Error::Crypto("Encrypted data too short".into()));
}
let nonce = &combined[..NONCE_LENGTH];
let tag = &combined[NONCE_LENGTH..NONCE_LENGTH + TAG_LENGTH];
let ciphertext = &combined[NONCE_LENGTH + TAG_LENGTH..];
let mut ct_with_tag = Vec::with_capacity(ciphertext.len() + TAG_LENGTH);
ct_with_tag.extend_from_slice(ciphertext);
ct_with_tag.extend_from_slice(tag);
let key = Key::from_slice(encryption_key);
let cipher = XChaCha20Poly1305::new(key);
let xnonce = XNonce::from_slice(nonce);
let plaintext_bytes = cipher
.decrypt(xnonce, Payload { msg: &ct_with_tag, aad: b"" })
.map_err(|e| Error::Crypto(format!("XChaCha20-Poly1305 decrypt failed: {}", e)))?;
String::from_utf8(plaintext_bytes)
.map_err(|e| Error::Crypto(format!("UTF-8 decode failed: {}", e)))
}
/// Deterministically expands `seed` into `length` pseudo-random bytes with
/// HKDF-SHA256 (empty salt).
///
/// A single HKDF expand can produce at most 255 * 32 bytes. Longer outputs are
/// therefore built from consecutive blocks, and block `i` uses the info string
/// `"{base_info}_block_{i}"`. A `length` of 0 yields an empty vector.
///
/// # Errors
/// `Error::Crypto` if an HKDF expand fails. This cannot happen for blocks no
/// larger than the limit, but the error is still propagated.
pub fn derive_random_bytes(seed: &[u8], base_info: &str, length: usize) -> Result<Vec<u8>> {
    // Maximum bytes one HKDF-SHA256 expand call may produce (255 hash lengths).
    const MAX_HKDF_OUTPUT: usize = 255 * 32;

    // HKDF-Extract depends only on (salt, seed), so compute the PRK once and
    // reuse it for every block instead of re-extracting on each iteration.
    let hk = Hkdf::<Sha256>::new(Some(&[]), seed);

    let mut result = vec![0u8; length];
    for (block_index, chunk) in result.chunks_mut(MAX_HKDF_OUTPUT).enumerate() {
        let info = format!("{}_block_{}", base_info, block_index);
        hk.expand(info.as_bytes(), chunk)
            .map_err(|e| Error::Crypto(format!("HKDF expand failed: {}", e)))?;
    }
    Ok(result)
}
#[cfg(test)]
mod tests {
    use super::*;

    // Standard BIP-39 test mnemonic with a valid checksum.
    const VALID_MNEMONIC: &str = "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about";
    // Every word is in the English wordlist, but the checksum does not match.
    const BAD_CHECKSUM_MNEMONIC: &str = "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon";

    /// Loads the cross-implementation test vectors shared by the parity tests.
    fn fixture() -> serde_json::Value {
        serde_json::from_str(include_str!("../tests/fixtures/crypto_vectors.json"))
            .expect("crypto_vectors.json is not valid JSON")
    }

    #[test]
    fn test_key_derivation_parity() {
        let vectors = fixture();
        let kd = &vectors["key_derivation"];
        let mnemonic = kd["mnemonic"].as_str().unwrap();
        let keys = derive_keys_from_mnemonic(mnemonic).unwrap();
        assert_eq!(hex::encode(keys.salt), kd["salt_hex"].as_str().unwrap());
        assert_eq!(hex::encode(keys.auth_key), kd["auth_key_hex"].as_str().unwrap());
        assert_eq!(
            hex::encode(keys.encryption_key),
            kd["encryption_key_hex"].as_str().unwrap()
        );
        assert_eq!(hex::encode(keys.dedup_key), kd["dedup_key_hex"].as_str().unwrap());
        let hash = compute_auth_key_hash(&keys.auth_key);
        assert_eq!(hash, kd["auth_key_hash"].as_str().unwrap());
    }

    #[test]
    fn test_bip39_seed_parity() {
        let vectors = fixture();
        let kd = &vectors["key_derivation"];
        let mnemonic = kd["mnemonic"].as_str().unwrap();
        let expected_seed_hex = kd["bip39_seed_hex"].as_str().unwrap();
        let seed = mnemonic_to_seed(mnemonic).unwrap();
        assert_eq!(hex::encode(seed), expected_seed_hex);
    }

    #[test]
    fn test_lsh_seed_parity() {
        let vectors = fixture();
        let mnemonic = vectors["key_derivation"]["mnemonic"].as_str().unwrap();
        let keys = derive_keys_from_mnemonic(mnemonic).unwrap();
        let lsh_seed = derive_lsh_seed(mnemonic, &keys.salt).unwrap();
        assert_eq!(
            hex::encode(lsh_seed),
            vectors["lsh"]["lsh_seed_hex"].as_str().unwrap()
        );
    }

    #[test]
    fn test_xchacha_fixed_nonce_parity() {
        let vectors = fixture();
        let xc = &vectors["xchacha20"];

        let key_bytes = hex::decode(xc["encryption_key_hex"].as_str().unwrap()).unwrap();
        let mut key = [0u8; 32];
        key.copy_from_slice(&key_bytes);

        let nonce_bytes = hex::decode(xc["fixed_nonce_hex"].as_str().unwrap()).unwrap();
        let mut nonce = [0u8; 24];
        nonce.copy_from_slice(&nonce_bytes);

        let plaintext = xc["plaintext"].as_str().unwrap();
        let expected_b64 = xc["fixed_nonce_encrypted_base64"].as_str().unwrap();

        let encrypted = encrypt_with_nonce(plaintext, &key, &nonce).unwrap();
        assert_eq!(encrypted, expected_b64);
        let decrypted = decrypt(&encrypted, &key).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_xchacha_round_trip() {
        let keys = derive_keys_from_mnemonic(VALID_MNEMONIC).unwrap();
        let plaintext = "Hello, TotalReclaw!";
        let encrypted = encrypt(plaintext, &keys.encryption_key).unwrap();
        let decrypted = decrypt(&encrypted, &keys.encryption_key).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_decrypt_rejects_tampered_ciphertext() {
        use base64::Engine;
        let key = [7u8; 32];
        let encrypted = encrypt("tamper me", &key).unwrap();
        let mut raw = base64::engine::general_purpose::STANDARD
            .decode(&encrypted)
            .unwrap();
        // Flip one bit in the last ciphertext byte; authentication must fail.
        *raw.last_mut().unwrap() ^= 0x01;
        let tampered = base64::engine::general_purpose::STANDARD.encode(&raw);
        assert!(decrypt(&tampered, &key).is_err());
    }

    #[test]
    fn test_decrypt_rejects_short_input() {
        // "AAAA" decodes to 3 bytes, shorter than nonce + tag.
        assert!(decrypt("AAAA", &[7u8; 32]).is_err());
    }

    #[test]
    fn test_derive_random_bytes_spans_multiple_blocks() {
        let long = derive_random_bytes(b"seed", "info", 10_000).unwrap();
        assert_eq!(long.len(), 10_000);
        assert_eq!(long, derive_random_bytes(b"seed", "info", 10_000).unwrap());
        // HKDF-Expand output does not depend on the requested length, so a
        // shorter request must be a prefix of a longer one.
        let short = derive_random_bytes(b"seed", "info", 32).unwrap();
        assert_eq!(&long[..32], &short[..]);
        assert!(derive_random_bytes(b"seed", "info", 0).unwrap().is_empty());
    }

    #[test]
    fn test_lenient_accepts_valid_mnemonic() {
        let strict = derive_keys_from_mnemonic(VALID_MNEMONIC).unwrap();
        let lenient = derive_keys_from_mnemonic_lenient(VALID_MNEMONIC).unwrap();
        assert_eq!(strict.auth_key, lenient.auth_key);
        assert_eq!(strict.encryption_key, lenient.encryption_key);
        assert_eq!(strict.dedup_key, lenient.dedup_key);
        assert_eq!(strict.salt, lenient.salt);
    }

    #[test]
    fn test_lenient_rejects_invalid_words() {
        let mnemonic = "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon xyzzy";
        let result = derive_keys_from_mnemonic_lenient(mnemonic);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("xyzzy"));
    }

    #[test]
    fn test_lenient_rejects_wrong_word_count() {
        let mnemonic = "abandon abandon abandon";
        let result = derive_keys_from_mnemonic_lenient(mnemonic);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("expected 12 or 24"));
    }

    #[test]
    fn test_strict_rejects_bad_checksum() {
        let result = derive_keys_from_mnemonic(BAD_CHECKSUM_MNEMONIC);
        assert!(result.is_err());
    }

    #[test]
    fn test_lenient_accepts_bad_checksum() {
        let result = derive_keys_from_mnemonic_lenient(BAD_CHECKSUM_MNEMONIC);
        assert!(result.is_ok());
    }
}