use std::{
fs::File,
io::{BufWriter, Read, Write},
path::PathBuf,
};
use anyhow::{Error, Result};
use argon2::{Argon2, PasswordHasher, password_hash::phc::SaltString};
use aws_lc_rs::{
aead::{AES_256_GCM_SIV, Aad, Nonce, RandomizedNonceKey},
agreement::{ECDH_P256, ECDH_P384, PrivateKey, PublicKey, X25519},
cipher::AES_256_KEY_LEN,
digest::SHA512_OUTPUT_LEN,
encoding::{AsBigEndian as _, Curve25519SeedBin, EcPrivateKeyBin},
hkdf::{HKDF_SHA512, Salt},
rand::fill,
};
#[cfg(feature = "unstable")]
use aws_lc_rs::{
encoding::AsRawBytes as _,
signature::KeyPair as _,
unstable::signature::{
ML_DSA_44_SIGNING, ML_DSA_65_SIGNING, ML_DSA_87_SIGNING, PqdsaKeyPair,
PqdsaSigningAlgorithm,
},
};
use base64::{Engine as _, engine::general_purpose::STANDARD};
use bytes::{Buf as _, BytesMut};
use getset::Getters;
use whoami::{hostname, username};
use crate::{KexMode, MoshpitError};
pub(crate) mod pk;
/// Magic prefix at the start of every decoded private-key file.
const KEY_HEADER: &[u8] = b"moshpit-key-v1";
/// Key-algorithm identifier for X25519 keys.
pub const KEY_ALGORITHM_X25519: &str = "X25519";
/// Key-algorithm identifier for NIST P-384 keys.
pub const KEY_ALGORITHM_P384: &str = "P384";
/// Key-algorithm identifier for NIST P-256 keys.
pub const KEY_ALGORITHM_P256: &str = "P256";
/// Key-algorithm identifier for ML-DSA-44 signing keys (`unstable` feature only).
#[cfg(feature = "unstable")]
pub const KEY_ALGORITHM_ML_DSA_44: &str = "ML-DSA-44";
/// Key-algorithm identifier for ML-DSA-65 signing keys (`unstable` feature only).
#[cfg(feature = "unstable")]
pub const KEY_ALGORITHM_ML_DSA_65: &str = "ML-DSA-65";
/// Key-algorithm identifier for ML-DSA-87 signing keys (`unstable` feature only).
#[cfg(feature = "unstable")]
pub const KEY_ALGORITHM_ML_DSA_87: &str = "ML-DSA-87";
/// Cipher name written for private keys stored without encryption.
const NONE_CIPHER: &str = "none";
/// KDF name written for private keys stored without encryption.
const NONE_KDF: &str = "none";
/// Cipher name written for passphrase-protected private keys.
const KEY_CIPHER: &str = "aes-256-gcm-siv";
/// HKDF `info` input used when deriving the private-key encryption key from a passphrase.
const HKDF_INFO: &[&[u8]] = &[b"moshpit HKDF"];
/// Map an ML-DSA key-algorithm identifier to its aws-lc-rs signing algorithm.
///
/// Any other identifier (including the ECDH ones) yields `None`.
#[cfg(feature = "unstable")]
fn resolve_pqdsa_signing_alg(key_alg: &str) -> Option<&'static PqdsaSigningAlgorithm> {
    let table: [(&str, &'static PqdsaSigningAlgorithm); 3] = [
        (KEY_ALGORITHM_ML_DSA_44, &ML_DSA_44_SIGNING),
        (KEY_ALGORITHM_ML_DSA_65, &ML_DSA_65_SIGNING),
        (KEY_ALGORITHM_ML_DSA_87, &ML_DSA_87_SIGNING),
    ];
    table
        .into_iter()
        .find(|(name, _)| *name == key_alg)
        .map(|(_, signing_alg)| signing_alg)
}
/// Whether `key_alg` names one of the ML-DSA algorithms compiled in by the `unstable` feature.
#[cfg(feature = "unstable")]
fn is_pqdsa_key_algorithm(key_alg: &str) -> bool {
    let resolved = resolve_pqdsa_signing_alg(key_alg);
    resolved.is_some()
}
/// Without the `unstable` feature no post-quantum algorithm is available, so every
/// identifier is reported as non-ML-DSA.
#[cfg(not(feature = "unstable"))]
fn is_pqdsa_key_algorithm(_: &str) -> bool {
    false
}
fn is_supported_key_algorithm(key_alg: &str) -> bool {
matches!(
key_alg,
KEY_ALGORITHM_X25519 | KEY_ALGORITHM_P384 | KEY_ALGORITHM_P256
) || is_pqdsa_key_algorithm(key_alg)
}
/// AEAD cipher identifiers that can appear in a private-key file header.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AEADCipher {
    /// The private key is stored without encryption (`"none"`).
    None,
    /// The private key is sealed with AES-256-GCM-SIV (`"aes-256-gcm-siv"`).
    Aes256GcmSiv,
}

impl AEADCipher {
    /// The identifier written to key files for this cipher.
    ///
    /// The result is a `'static` constant, so it may outlive `self`.
    #[must_use]
    pub fn as_str(&self) -> &'static str {
        match self {
            AEADCipher::None => NONE_CIPHER,
            AEADCipher::Aes256GcmSiv => KEY_CIPHER,
        }
    }

    /// [`AEADCipher::as_str`] as raw bytes.
    #[must_use]
    pub fn as_bytes(&self) -> &'static [u8] {
        self.as_str().as_bytes()
    }
}

impl TryFrom<&str> for AEADCipher {
    type Error = Error;

    /// Parse a cipher identifier; accepts exactly what the `&[u8]` impl accepts.
    fn try_from(value: &str) -> Result<Self> {
        Self::try_from(value.as_bytes())
    }
}

impl TryFrom<&[u8]> for AEADCipher {
    type Error = Error;

    /// Parse a cipher identifier as stored in a key file.
    ///
    /// Matching goes through [`AEADCipher::as_bytes`], so the accepted spellings
    /// always agree with the names written by this module.
    ///
    /// # Errors
    ///
    /// Returns `MoshpitError::UnsupportedAeadCipher` for any other identifier.
    fn try_from(value: &[u8]) -> Result<Self> {
        [AEADCipher::None, AEADCipher::Aes256GcmSiv]
            .into_iter()
            .find(|cipher| cipher.as_bytes() == value)
            .ok_or_else(|| MoshpitError::UnsupportedAeadCipher.into())
    }
}
/// An ECDH key pair decoded from a private-key file that was stored without encryption.
#[derive(Debug, Getters)]
#[getset(get = "pub")]
pub struct UnencryptedKeyPair {
    /// The agreement private key.
    private_key: PrivateKey,
    /// The public key belonging to `private_key`.
    public_key: PublicKey,
}
impl UnencryptedKeyPair {
    /// Consume the pair and hand back `(private_key, public_key)`.
    #[must_use]
    pub fn take(self) -> (PrivateKey, PublicKey) {
        let Self {
            private_key,
            public_key,
        } = self;
        (private_key, public_key)
    }
}
/// Header fields and still-encrypted private key of a passphrase-protected key
/// file, as returned by `load_private_key`. Nothing in here has been decrypted.
#[derive(Debug, Getters)]
#[getset(get = "pub")]
pub struct EncryptedKeyPair {
    /// KDF field from the file header; keys written by this module store an
    /// Argon2 PHC hash string of the passphrase here.
    kdf: String,
    /// Public-key bytes stored in the file header.
    public_key: Vec<u8>,
    /// Algorithm identifier (`load_private_key` only builds this for X25519, P384 or P256).
    key_algorithm: String,
    /// HKDF salt used to derive the AEAD key from the passphrase.
    salt_bytes: Vec<u8>,
    /// AEAD nonce for the sealed private key.
    nonce_bytes: Vec<u8>,
    /// Sealed private key (for files written by this module: AES-256-GCM-SIV
    /// ciphertext with the tag appended).
    encrypted_private_key_bytes: Vec<u8>,
}
/// Raw key material loaded (and, if needed, decrypted) by `load_identity_key`.
///
/// `Debug` is implemented by hand so the private key bytes are never printed
/// in logs or panic messages.
#[derive(Getters)]
#[getset(get = "pub")]
pub struct IdentityKeyPair {
    /// Algorithm identifier from the key file (e.g. `"X25519"`).
    key_algorithm: String,
    /// Raw public-key bytes as stored in the key file.
    public_key: Vec<u8>,
    /// Raw, decrypted private-key bytes.
    private_key: Vec<u8>,
}

impl std::fmt::Debug for IdentityKeyPair {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `private_key` is deliberately omitted; `finish_non_exhaustive` marks the gap.
        f.debug_struct("IdentityKeyPair")
            .field("key_algorithm", &self.key_algorithm)
            .field("public_key", &self.public_key)
            .finish_non_exhaustive()
    }
}
/// A freshly generated key pair, already encoded for writing to disk.
///
/// `Debug` is implemented by hand so the encoded private key is never printed:
/// for keys generated without a passphrase it embeds the plaintext private key.
#[derive(Getters)]
#[getset(get = "pub")]
pub struct KeyPair {
    /// Base64 private-key file contents (see `write_private_key`).
    private_key: String,
    /// Base64 encoding of `public_key_bytes`.
    public_key: String,
    /// Length-prefixed algorithm name followed by the length-prefixed public key.
    public_key_bytes: Vec<u8>,
}

impl std::fmt::Debug for KeyPair {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `private_key` is deliberately omitted; `finish_non_exhaustive` marks the gap.
        f.debug_struct("KeyPair")
            .field("public_key", &self.public_key)
            .field("public_key_bytes", &self.public_key_bytes)
            .finish_non_exhaustive()
    }
}
impl KeyPair {
pub fn default_key_path_ext(mode: KexMode, key_alg: &str) -> Result<(PathBuf, &'static str)> {
let base_dir = dirs2::home_dir().ok_or(MoshpitError::HomeDir)?.join(".mp");
let stem = key_alg.to_lowercase().replace('-', "_");
Ok(match mode {
KexMode::Client => (base_dir.join(format!("id_{stem}")), "pub"),
KexMode::Server(_socket_addr) => (base_dir.join(format!("mps_host_{stem}_key")), "pub"),
})
}
pub fn generate_key_pair(
passphrase_opt: Option<&String>,
mode: KexMode,
key_alg: &str,
) -> Result<Self> {
if matches!(mode, KexMode::Client) && passphrase_opt.is_none_or(String::is_empty) {
return Err(anyhow::anyhow!(
"A non-empty passphrase is required to protect the private key"
));
}
#[cfg(feature = "unstable")]
if let Some(alg) = resolve_pqdsa_signing_alg(key_alg) {
let key_pair = PqdsaKeyPair::generate(alg)?;
let public_key = key_pair.public_key().as_ref();
let private_key = key_pair.private_key().as_raw_bytes()?;
let (public_key_bytes, public_key_encoded) = generate_public_key(key_alg, public_key)?;
let mut priv_key_bytes = private_key.as_ref().to_vec();
let private_key_encoded =
generate_private_key(&mut priv_key_bytes, public_key, passphrase_opt, key_alg)?;
return Ok(KeyPair {
private_key: private_key_encoded,
public_key: public_key_encoded,
public_key_bytes,
});
}
let alg = match key_alg {
KEY_ALGORITHM_X25519 => &X25519,
KEY_ALGORITHM_P384 => &ECDH_P384,
KEY_ALGORITHM_P256 => &ECDH_P256,
_ => return Err(anyhow::anyhow!("Unknown key algorithm: {key_alg}")),
};
let priv_key = PrivateKey::generate(alg)?;
let public_key = priv_key.compute_public_key()?;
let (public_key_bytes, public_key_encoded) =
generate_public_key(key_alg, public_key.as_ref())?;
let mut priv_key_bytes = if key_alg == KEY_ALGORITHM_X25519 {
let bytes: Curve25519SeedBin<'_> = priv_key.as_be_bytes()?;
bytes.as_ref().to_vec()
} else {
let bytes: EcPrivateKeyBin<'_> = priv_key.as_be_bytes()?;
bytes.as_ref().to_vec()
};
let private_key_encoded = generate_private_key(
&mut priv_key_bytes,
public_key.as_ref(),
passphrase_opt,
key_alg,
)?;
Ok(KeyPair {
private_key: private_key_encoded,
public_key: public_key_encoded,
public_key_bytes,
})
}
pub fn write_private_key<T>(&self, writer: &mut T) -> Result<()>
where
T: Write,
{
let mut buf_writer = BufWriter::new(writer);
buf_writer.write_all(self.private_key.as_bytes())?;
Ok(())
}
pub fn write_public_key<T>(&self, writer: &mut T) -> Result<()>
where
T: Write,
{
let mut pub_buf_writer = BufWriter::new(writer);
pub_buf_writer.write_all(b"moshpit ")?;
pub_buf_writer.write_all(self.public_key.as_bytes())?;
let username = username().unwrap_or("unknown-user".to_string());
let hostname = hostname().unwrap_or("unknown-host".to_string());
pub_buf_writer.write_all(format!(" {username}@{hostname}").as_bytes())?;
Ok(())
}
pub fn fingerprint(&self) -> Result<String> {
pk::fingerprint(&self.public_key_bytes)
}
#[must_use]
pub fn randomart(&self) -> String {
pk::randomart(&self.public_key_bytes)
}
}
/// Append `alg` to `key` as a length-prefixed field: a 4-byte big-endian length
/// followed by the UTF-8 bytes of the name.
fn add_key_alg(key: &mut Vec<u8>, alg: &str) -> Result<()> {
    let len_prefix = as_be_bytes(alg.len())?;
    key.extend(len_prefix.iter().chain(alg.as_bytes()));
    Ok(())
}
/// Append the cipher name and then the KDF string to `key`, each as a
/// length-prefixed field (4-byte big-endian length, then the bytes).
fn add_cipher_and_kdf(key: &mut Vec<u8>, cipher: &str, kdf: &str) -> Result<()> {
    for field in [cipher, kdf] {
        key.extend_from_slice(&as_be_bytes(field.len())?);
        key.extend_from_slice(field.as_bytes());
    }
    Ok(())
}
/// Build the public-key blob `len(alg) || alg || len(pk) || pk` and its standard
/// base64 encoding.
///
/// Returns `(raw_blob, base64_blob)`.
fn generate_public_key(alg: &str, public_key: &[u8]) -> Result<(Vec<u8>, String)> {
    let mut blob = Vec::new();
    add_key_alg(&mut blob, alg)?;
    let pk_len = as_be_bytes(public_key.len())?;
    blob.extend_from_slice(&pk_len);
    blob.extend_from_slice(public_key);
    let b64 = STANDARD.encode(&blob);
    Ok((blob, b64))
}
fn generate_private_key(
private_key: &mut Vec<u8>,
public_key: &[u8],
passphrase_opt: Option<&String>,
alg: &str,
) -> Result<String> {
let mut private_key_bytes = vec![];
private_key_bytes.extend_from_slice(KEY_HEADER);
let passphrase_hash_opt = generate_passphrase_hash(passphrase_opt);
if let Some((passphrase, passphrase_hash)) = passphrase_opt.zip(passphrase_hash_opt) {
setup_encrypted_private_key(
&mut private_key_bytes,
private_key,
public_key,
passphrase,
&passphrase_hash,
alg,
)?;
} else {
setup_unencrypted_private_key(&mut private_key_bytes, private_key, public_key, alg)?;
}
Ok(STANDARD.encode(&private_key_bytes))
}
/// Append the body of an encrypted key file to `private_key_bytes`.
///
/// The body is the cipher/KDF header, the algorithm, the length-prefixed public
/// key and then the sealed private key (via `encrypt_private_key`). `private_key`
/// is sealed in place, so on success it holds ciphertext plus tag.
fn setup_encrypted_private_key(
    private_key_bytes: &mut Vec<u8>,
    private_key: &mut Vec<u8>,
    public_key: &[u8],
    passphrase: &str,
    passphrase_hash: &str,
    alg: &str,
) -> Result<()> {
    add_cipher_and_kdf(private_key_bytes, KEY_CIPHER, passphrase_hash)?;
    add_key_alg(private_key_bytes, alg)?;
    let pk_len = as_be_bytes(public_key.len())?;
    private_key_bytes.extend(pk_len.iter().chain(public_key));
    encrypt_private_key(private_key_bytes, private_key, passphrase)
}
/// Append the body of an unencrypted key file to `private_key_bytes`.
///
/// The body is the `"none"` cipher/KDF header, the algorithm, and then the
/// length-prefixed public key and plaintext private key.
fn setup_unencrypted_private_key(
    private_key_bytes: &mut Vec<u8>,
    private_key: &[u8],
    public_key: &[u8],
    alg: &str,
) -> Result<()> {
    add_cipher_and_kdf(private_key_bytes, NONE_CIPHER, NONE_KDF)?;
    add_key_alg(private_key_bytes, alg)?;
    for field in [public_key, private_key] {
        private_key_bytes.extend_from_slice(&as_be_bytes(field.len())?);
        private_key_bytes.extend_from_slice(field);
    }
    Ok(())
}
/// Hash the passphrase with `Argon2::default()` and a freshly generated salt,
/// returning the hash in its string (PHC) form.
///
/// Returns `None` both when no passphrase is given and when hashing fails; the
/// return value does not distinguish the two.
fn generate_passphrase_hash(passphrase_opt: Option<&String>) -> Option<String> {
    let passphrase = passphrase_opt?;
    let salt = SaltString::generate();
    let hash = Argon2::default()
        .hash_password_with_salt(passphrase.as_bytes(), salt.as_bytes())
        .ok()?;
    Some(hash.to_string())
}
/// Seal `private_key` in place with AES-256-GCM-SIV under a passphrase-derived
/// key, then append the salt, nonce and ciphertext to `private_key_bytes`.
///
/// Each of the three fields is written as a 4-byte big-endian length followed by
/// the bytes. The AEAD key comes from HKDF-SHA512 over `passphrase`, using a fresh
/// random `SHA512_OUTPUT_LEN`-byte salt and `HKDF_INFO`. The nonce is returned by
/// `RandomizedNonceKey`, and no additional authenticated data is used.
///
/// NOTE(review): the AEAD key is derived from the passphrase with HKDF alone. The
/// Argon2 hash stored in the KDF field plays no part, so each offline guess costs
/// an HKDF evaluation rather than an Argon2 one. Changing this alters the key-file
/// format, so confirm before relying on Argon2 hardening.
fn encrypt_private_key(
    private_key_bytes: &mut Vec<u8>,
    private_key: &mut Vec<u8>,
    passphrase: &str,
) -> Result<()> {
    use zeroize::Zeroize;

    let mut hkdf_salt = [0u8; SHA512_OUTPUT_LEN];
    fill(&mut hkdf_salt)?;
    let prk = Salt::new(HKDF_SHA512, &hkdf_salt).extract(passphrase.as_bytes());
    let mut aead_key = [0u8; AES_256_KEY_LEN];
    prk.expand(HKDF_INFO, &AES_256_GCM_SIV)?.fill(&mut aead_key)?;
    let sealing_key = RandomizedNonceKey::new(&AES_256_GCM_SIV, &aead_key)?;
    // The raw key bytes are no longer needed once the AEAD key object exists.
    aead_key.zeroize();
    let nonce = sealing_key.seal_in_place_append_tag(Aad::empty(), private_key)?;
    let nonce_bytes = nonce.as_ref();
    for field in [&hkdf_salt[..], &nonce_bytes[..], &private_key[..]] {
        private_key_bytes.extend_from_slice(&as_be_bytes(field.len())?);
        private_key_bytes.extend_from_slice(field);
    }
    Ok(())
}
/// Decrypt `encrypted_private_key_bytes` in place with a key derived from `passphrase`.
///
/// On success, only the leading plaintext-length bytes of the buffer are the
/// private key. This function does not report that length.
///
/// # Errors
///
/// Fails if the nonce is malformed or authentication fails (for example, a wrong
/// passphrase).
pub fn decrypt_private_key(
    passphrase: &str,
    salt_bytes: &[u8],
    nonce_bytes: &[u8],
    encrypted_private_key_bytes: &mut [u8],
) -> Result<()> {
    decrypt_private_key_in_place(passphrase, salt_bytes, nonce_bytes, encrypted_private_key_bytes)
        .map(|_plaintext_len| ())
}
/// Decrypt a copy of `encrypted_private_key_bytes` and return just the plaintext
/// private key, with the authentication tag trimmed off.
fn decrypt_private_key_to_vec(
    passphrase: &str,
    salt_bytes: &[u8],
    nonce_bytes: &[u8],
    encrypted_private_key_bytes: &[u8],
) -> Result<Vec<u8>> {
    let mut buf = Vec::from(encrypted_private_key_bytes);
    let len = decrypt_private_key_in_place(passphrase, salt_bytes, nonce_bytes, &mut buf)?;
    buf.truncate(len);
    Ok(buf)
}
fn decrypt_private_key_in_place(
passphrase: &str,
salt_bytes: &[u8],
nonce_bytes: &[u8],
encrypted_private_key_bytes: &mut [u8],
) -> Result<usize> {
use zeroize::Zeroize;
let key_bytes = passphrase.as_bytes();
let salt = Salt::new(HKDF_SHA512, salt_bytes);
let pseudo_random_key = salt.extract(key_bytes);
let okm_aes = pseudo_random_key.expand(HKDF_INFO, &AES_256_GCM_SIV)?;
let mut derived_key = [0u8; AES_256_KEY_LEN];
okm_aes.fill(&mut derived_key)?;
let rnk = RandomizedNonceKey::new(&AES_256_GCM_SIV, &derived_key)?;
derived_key.zeroize();
let nonce = Nonce::try_assume_unique_for_key(nonce_bytes)?;
let plaintext = rnk.open_in_place(nonce, Aad::empty(), encrypted_private_key_bytes)?;
Ok(plaintext.len())
}
/// Encode `value` as the 4-byte big-endian length prefix used throughout the key format.
///
/// # Errors
///
/// Fails if `value` does not fit in a `u32`.
fn as_be_bytes(value: usize) -> Result<[u8; 4]> {
    let narrowed: u32 = value.try_into()?;
    Ok(narrowed.to_be_bytes())
}
/// Read a public-key file of the form written by `KeyPair::write_public_key`.
///
/// The file must contain exactly three whitespace-separated fields; the second
/// is the base64 public-key blob. Returns the raw file contents together with
/// the decoded public-key bytes.
///
/// # Errors
///
/// Returns I/O or base64 errors. Returns `MoshpitError::InvalidKeyHeader` for a
/// field count other than three, a non-UTF-8 algorithm name, or an unsupported
/// algorithm.
pub fn load_public_key(pub_key_path: &PathBuf) -> Result<(Vec<u8>, Vec<u8>)> {
    let mut file_bytes = Vec::new();
    File::open(pub_key_path)?.read_to_end(&mut file_bytes)?;
    let text = String::from_utf8_lossy(&file_bytes);
    let fields: Vec<&str> = text.split_whitespace().collect();
    let &[_, encoded, _] = fields.as_slice() else {
        return Err(MoshpitError::InvalidKeyHeader.into());
    };
    let decoded = STANDARD.decode(encoded)?;
    let mut blob = BytesMut::from(decoded.as_slice());
    let alg_field = get_val_by_len(&mut blob)?;
    let alg_name =
        std::str::from_utf8(&alg_field).map_err(|_| MoshpitError::InvalidKeyHeader)?;
    if !is_supported_key_algorithm(alg_name) {
        return Err(MoshpitError::InvalidKeyHeader.into());
    }
    let key_field = get_val_by_len(&mut blob)?;
    Ok((file_bytes, key_field.to_vec()))
}
/// Load an ECDH private-key file written by `KeyPair::write_private_key`.
///
/// Exactly one element of the returned tuple is `Some`:
/// - `(Some(pair), None)` for an unencrypted key. The stored public key has been
///   checked against the one computed from the private key.
/// - `(None, Some(pair))` for an encrypted key. Nothing is decrypted here.
///
/// Only X25519, P384 and P256 keys are accepted (use `load_identity_key` for
/// ML-DSA keys). Whitespace around the base64 text, such as a trailing newline,
/// is ignored.
///
/// # Errors
///
/// - I/O and base64 errors.
/// - `MoshpitError::InvalidKeyHeader` for a missing/incorrect header or an
///   unsupported algorithm.
/// - `MoshpitError::PublicKeyMismatch` for an inconsistent unencrypted key.
pub fn load_private_key(
    priv_key_path: &PathBuf,
) -> Result<(Option<UnencryptedKeyPair>, Option<EncryptedKeyPair>)> {
    let mut file_bytes = vec![];
    File::open(priv_key_path)?.read_to_end(&mut file_bytes)?;
    // Tolerate surrounding whitespace such as a trailing newline; base64 never emits any.
    let decoded = STANDARD.decode(file_bytes.trim_ascii())?;
    // `strip_prefix` instead of `BytesMut::split_to`, which panics when the decoded
    // data is shorter than the header.
    let Some(body) = decoded.strip_prefix(KEY_HEADER) else {
        return Err(MoshpitError::InvalidKeyHeader.into());
    };
    let mut private_key_bytes = BytesMut::from(body);
    let cipher = get_val_by_len(&mut private_key_bytes)?;
    let kdf = get_val_by_len(&mut private_key_bytes)?;
    let key_alg = get_val_by_len(&mut private_key_bytes)?;
    let key_alg_str = std::str::from_utf8(&key_alg).map_err(|_| MoshpitError::InvalidKeyHeader)?;
    // Only agreement (ECDH) keys fit `UnencryptedKeyPair`; anything else is rejected.
    let agreement_alg: &aws_lc_rs::agreement::Algorithm = match key_alg_str {
        KEY_ALGORITHM_X25519 => &X25519,
        KEY_ALGORITHM_P384 => &ECDH_P384,
        KEY_ALGORITHM_P256 => &ECDH_P256,
        _ => return Err(MoshpitError::InvalidKeyHeader.into()),
    };
    if cipher == NONE_CIPHER.as_bytes() && kdf == NONE_KDF.as_bytes() {
        let pub_key_bytes = get_val_by_len(&mut private_key_bytes)?;
        let priv_key_bytes = get_val_by_len(&mut private_key_bytes)?;
        let private_key = PrivateKey::from_private_key(agreement_alg, &priv_key_bytes)?;
        let public_key = private_key.compute_public_key()?;
        if public_key.as_ref() != pub_key_bytes.as_ref() {
            return Err(MoshpitError::PublicKeyMismatch.into());
        }
        let unencrypted_key_pair = UnencryptedKeyPair {
            private_key,
            public_key,
        };
        Ok((Some(unencrypted_key_pair), None))
    } else {
        let pub_key_bytes = get_val_by_len(&mut private_key_bytes)?;
        let salt_bytes = get_val_by_len(&mut private_key_bytes)?;
        let nonce_bytes = get_val_by_len(&mut private_key_bytes)?;
        let encrypted_priv_key_bytes = get_val_by_len(&mut private_key_bytes)?;
        let encrypted_key_pair = EncryptedKeyPair {
            kdf: String::from_utf8_lossy(&kdf).into_owned(),
            public_key: pub_key_bytes.to_vec(),
            key_algorithm: key_alg_str.to_string(),
            salt_bytes: salt_bytes.to_vec(),
            nonce_bytes: nonce_bytes.to_vec(),
            encrypted_private_key_bytes: encrypted_priv_key_bytes.to_vec(),
        };
        Ok((None, Some(encrypted_key_pair)))
    }
}
/// Load a private-key file of any supported algorithm and return its raw key material.
///
/// Supported algorithms are ECDH and, with the `unstable` feature, ML-DSA.
/// Encrypted keys are decrypted with `passphrase`. The resulting pair is always
/// checked with `validate_identity_key_pair`. Whitespace around the base64 text,
/// such as a trailing newline, is ignored.
///
/// # Errors
///
/// - I/O and base64 errors.
/// - `MoshpitError::InvalidKeyHeader` for a missing/incorrect header or an
///   unsupported algorithm.
/// - `MoshpitError::KeyCorrupt` when an encrypted key is loaded without a
///   passphrase.
/// - A decryption failure (e.g. wrong passphrase).
/// - A public/private key mismatch.
pub fn load_identity_key(
    priv_key_path: &PathBuf,
    passphrase: Option<&str>,
) -> Result<IdentityKeyPair> {
    let mut file_bytes = vec![];
    File::open(priv_key_path)?.read_to_end(&mut file_bytes)?;
    // Tolerate surrounding whitespace such as a trailing newline; base64 never emits any.
    let decoded = STANDARD.decode(file_bytes.trim_ascii())?;
    // `strip_prefix` instead of `BytesMut::split_to`, which panics when the decoded
    // data is shorter than the header.
    let Some(body) = decoded.strip_prefix(KEY_HEADER) else {
        return Err(MoshpitError::InvalidKeyHeader.into());
    };
    let mut private_key_bytes = BytesMut::from(body);
    let cipher = get_val_by_len(&mut private_key_bytes)?;
    let kdf = get_val_by_len(&mut private_key_bytes)?;
    let key_alg = get_val_by_len(&mut private_key_bytes)?;
    let key_alg_str = std::str::from_utf8(&key_alg).map_err(|_| MoshpitError::InvalidKeyHeader)?;
    if !is_supported_key_algorithm(key_alg_str) {
        return Err(MoshpitError::InvalidKeyHeader.into());
    }
    let public_key = get_val_by_len(&mut private_key_bytes)?.to_vec();
    let private_key = if cipher == NONE_CIPHER.as_bytes() && kdf == NONE_KDF.as_bytes() {
        get_val_by_len(&mut private_key_bytes)?.to_vec()
    } else {
        let salt_bytes = get_val_by_len(&mut private_key_bytes)?;
        let nonce_bytes = get_val_by_len(&mut private_key_bytes)?;
        let encrypted_priv_key_bytes = get_val_by_len(&mut private_key_bytes)?;
        let passphrase = passphrase.ok_or(MoshpitError::KeyCorrupt)?;
        decrypt_private_key_to_vec(
            passphrase,
            &salt_bytes,
            &nonce_bytes,
            &encrypted_priv_key_bytes,
        )?
    };
    validate_identity_key_pair(key_alg_str, &public_key, &private_key)?;
    Ok(IdentityKeyPair {
        key_algorithm: key_alg_str.to_string(),
        public_key,
        private_key,
    })
}
/// Check that `public_key` is the public half of `private_key` for `key_alg`.
///
/// ECDH keys recompute the public key from the private key. ML-DSA keys (with
/// the `unstable` feature) rebuild the key pair from the raw private key.
///
/// # Errors
///
/// - `MoshpitError::PublicKeyMismatch` if the keys do not correspond.
/// - `MoshpitError::InvalidKeyHeader` for an unsupported algorithm.
/// - Any error raised while parsing `private_key`.
pub fn validate_identity_key_pair(
    key_alg: &str,
    public_key: &[u8],
    private_key: &[u8],
) -> Result<()> {
    let check = |derived: &[u8]| -> Result<()> {
        if derived == public_key {
            Ok(())
        } else {
            Err(MoshpitError::PublicKeyMismatch.into())
        }
    };

    let ecdh_alg: Option<&aws_lc_rs::agreement::Algorithm> = match key_alg {
        KEY_ALGORITHM_X25519 => Some(&X25519),
        KEY_ALGORITHM_P384 => Some(&ECDH_P384),
        KEY_ALGORITHM_P256 => Some(&ECDH_P256),
        _ => None,
    };
    if let Some(alg) = ecdh_alg {
        let derived = PrivateKey::from_private_key(alg, private_key)?.compute_public_key()?;
        return check(derived.as_ref());
    }

    #[cfg(feature = "unstable")]
    if let Some(alg) = resolve_pqdsa_signing_alg(key_alg) {
        let key_pair = PqdsaKeyPair::from_raw_private_key(alg, private_key)?;
        return check(key_pair.public_key().as_ref());
    }

    Err(MoshpitError::InvalidKeyHeader.into())
}
fn get_val_by_len(bytes: &mut BytesMut) -> Result<BytesMut> {
let len_bytes = usize::try_from(bytes.get_u32())?;
let val_bytes = bytes.split_to(len_bytes);
Ok(val_bytes)
}
// Unit tests for key generation, serialization and loading. Fixture files are read
// through relative paths under `tests/keys/`:
// - `id_x25519_test` is an unencrypted X25519 key.
// - `id_x25519_test_enc` is an X25519 key protected with the passphrase "test".
#[cfg(test)]
mod tests {
    use std::path::PathBuf;
    use anyhow::Result;
    use argon2::{Argon2, PasswordHash, PasswordVerifier as _};
    use base64::Engine;
    use super::{
        decrypt_private_key, load_identity_key, load_private_key, validate_identity_key_pair,
    };
    // The unencrypted fixture loads as an `UnencryptedKeyPair` carrying the expected public key.
    #[test]
    fn test_load_private_key_unenc() {
        let priv_key_path = PathBuf::from("tests/keys/id_x25519_test");
        let result = load_private_key(&priv_key_path);
        assert!(result.is_ok());
        let (unencrypted_key_pair_opt, encrypted_key_pair_opt) = result.unwrap();
        assert!(unencrypted_key_pair_opt.is_some());
        assert!(encrypted_key_pair_opt.is_none());
        let unencrypted_key_pair = unencrypted_key_pair_opt.unwrap();
        let public_key_bytes = unencrypted_key_pair.public_key.as_ref();
        let expected_public_key_bytes = vec![
            0x38, 0x43, 0x92, 0xD7, 0x3E, 0xEA, 0x2F, 0x77, 0x6B, 0x45, 0x7B, 0x99, 0xFD, 0xD6,
            0x9D, 0x5B, 0x11, 0xF2, 0x3E, 0x8D, 0xB7, 0x13, 0x0B, 0xF7, 0x54, 0xF0, 0xC8, 0x49,
            0x93, 0xD4, 0xF5, 0x5B,
        ];
        assert_eq!(public_key_bytes, expected_public_key_bytes.as_slice());
    }
    // The encrypted fixture loads as an `EncryptedKeyPair`. Its KDF field is an
    // Argon2id hash that verifies against "test", and decrypting in place with that
    // passphrase succeeds and changes the buffer.
    #[test]
    fn test_load_private_key_enc() -> Result<()> {
        let priv_key_path = PathBuf::from("tests/keys/id_x25519_test_enc");
        let result = load_private_key(&priv_key_path);
        assert!(result.is_ok());
        let (unencrypted_key_pair_opt, encrypted_key_pair_opt) = result.unwrap();
        assert!(unencrypted_key_pair_opt.is_none());
        assert!(encrypted_key_pair_opt.is_some());
        let encrypted_key_pair = encrypted_key_pair_opt.unwrap();
        assert!(encrypted_key_pair.kdf.starts_with("$argon2id$"));
        let public_key_bytes = encrypted_key_pair.public_key.as_slice();
        let expected_public_key_bytes = vec![
            0x45, 0xDA, 0x9E, 0xCC, 0x73, 0xE8, 0x69, 0xE1, 0x98, 0xAF, 0xD9, 0x57, 0xD0, 0xAA,
            0xA4, 0x2D, 0xA9, 0x52, 0xD0, 0x9C, 0xE3, 0x7B, 0x0A, 0x93, 0xEA, 0x9D, 0xDF, 0x6F,
            0x4D, 0x54, 0x3F, 0x2F,
        ];
        assert_eq!(public_key_bytes, expected_public_key_bytes.as_slice());
        let parsed_hash = PasswordHash::new(&encrypted_key_pair.kdf)?;
        let argon2 = Argon2::default();
        assert!(argon2.verify_password(b"test", &parsed_hash).is_ok());
        let salt_bytes = encrypted_key_pair.salt_bytes.as_slice();
        let nonce_bytes = encrypted_key_pair.nonce_bytes.as_slice();
        let encrypted_private_key_bytes = encrypted_key_pair.encrypted_private_key_bytes.clone();
        let mut decrypted_bytes = encrypted_key_pair.encrypted_private_key_bytes.clone();
        let decrypt_res =
            decrypt_private_key("test", salt_bytes, nonce_bytes, &mut decrypted_bytes);
        assert!(decrypt_res.is_ok());
        // Decryption happens in place, so the buffer must no longer equal the ciphertext.
        assert_ne!(encrypted_private_key_bytes, decrypted_bytes);
        Ok(())
    }
    // A server key generated without a passphrase is encoded with the "none" cipher and KDF.
    #[test]
    fn test_generate_key_pair_unencrypted() -> Result<()> {
        let key_pair = super::KeyPair::generate_key_pair(
            None,
            super::KexMode::Server("0.0.0.0:0".parse().unwrap()),
            super::KEY_ALGORITHM_X25519,
        )?;
        let mut priv_key_bytes = vec![];
        key_pair.write_private_key(&mut priv_key_bytes)?;
        let decoded = base64::engine::general_purpose::STANDARD.decode(&priv_key_bytes)?;
        let mut buf = bytes::BytesMut::from(&decoded[..]);
        let header = buf.split_to(super::KEY_HEADER.len());
        assert_eq!(&header[..], super::KEY_HEADER);
        let cipher = super::get_val_by_len(&mut buf)?;
        let kdf = super::get_val_by_len(&mut buf)?;
        assert_eq!(&cipher[..], super::NONE_CIPHER.as_bytes());
        assert_eq!(&kdf[..], super::NONE_KDF.as_bytes());
        Ok(())
    }
    // Each ML-DSA variant round-trips unencrypted through `write_private_key` and
    // `load_identity_key` (`unstable` feature only).
    #[cfg(feature = "unstable")]
    #[test]
    fn test_generate_and_load_ml_dsa_identity_key() -> Result<()> {
        for key_alg in [
            super::KEY_ALGORITHM_ML_DSA_44,
            super::KEY_ALGORITHM_ML_DSA_65,
            super::KEY_ALGORITHM_ML_DSA_87,
        ] {
            let key_pair = super::KeyPair::generate_key_pair(
                None,
                super::KexMode::Server("0.0.0.0:0".parse().unwrap()),
                key_alg,
            )?;
            let dir = tempfile::TempDir::new()?;
            let key_path = dir.path().join("id_mldsa");
            let mut private_key = std::fs::File::create(&key_path)?;
            key_pair.write_private_key(&mut private_key)?;
            let loaded = load_identity_key(&key_path, None)?;
            assert_eq!(loaded.key_algorithm(), key_alg);
            assert!(!key_pair.public_key_bytes().is_empty());
            assert!(!loaded.public_key().is_empty());
            assert!(!loaded.private_key().is_empty());
        }
        Ok(())
    }
    // `load_identity_key` reads the unencrypted fixture without a passphrase.
    #[test]
    fn test_load_identity_key_unenc_x25519() {
        let path = PathBuf::from("tests/keys/id_x25519_test");
        let key = load_identity_key(&path, None).expect("load unencrypted key");
        assert_eq!(key.key_algorithm(), super::KEY_ALGORITHM_X25519);
        let expected = vec![
            0x38, 0x43, 0x92, 0xD7, 0x3E, 0xEA, 0x2F, 0x77, 0x6B, 0x45, 0x7B, 0x99, 0xFD, 0xD6,
            0x9D, 0x5B, 0x11, 0xF2, 0x3E, 0x8D, 0xB7, 0x13, 0x0B, 0xF7, 0x54, 0xF0, 0xC8, 0x49,
            0x93, 0xD4, 0xF5, 0x5B,
        ];
        assert_eq!(key.public_key(), &expected);
        assert!(!key.private_key().is_empty());
    }
    // `load_identity_key` decrypts the encrypted fixture with the correct passphrase.
    #[test]
    fn test_load_identity_key_enc_x25519() {
        let path = PathBuf::from("tests/keys/id_x25519_test_enc");
        let key = load_identity_key(&path, Some("test")).expect("load encrypted key");
        assert_eq!(key.key_algorithm(), super::KEY_ALGORITHM_X25519);
        let expected = vec![
            0x45, 0xDA, 0x9E, 0xCC, 0x73, 0xE8, 0x69, 0xE1, 0x98, 0xAF, 0xD9, 0x57, 0xD0, 0xAA,
            0xA4, 0x2D, 0xA9, 0x52, 0xD0, 0x9C, 0xE3, 0x7B, 0x0A, 0x93, 0xEA, 0x9D, 0xDF, 0x6F,
            0x4D, 0x54, 0x3F, 0x2F,
        ];
        assert_eq!(key.public_key(), &expected);
        assert!(!key.private_key().is_empty());
    }
    // A wrong passphrase must fail decryption rather than yield garbage key material.
    #[test]
    fn test_load_identity_key_enc_wrong_passphrase() {
        let path = PathBuf::from("tests/keys/id_x25519_test_enc");
        assert!(load_identity_key(&path, Some("wrong")).is_err());
    }
    // An encrypted key cannot be loaded without a passphrase.
    #[test]
    fn test_load_identity_key_enc_no_passphrase() {
        let path = PathBuf::from("tests/keys/id_x25519_test_enc");
        assert!(load_identity_key(&path, None).is_err());
    }
    // Valid base64 that does not start with `KEY_HEADER` is rejected.
    #[test]
    fn test_load_identity_key_invalid_header() -> Result<()> {
        let dir = tempfile::TempDir::new()?;
        let path = dir.path().join("bad_key");
        let garbage =
            base64::engine::general_purpose::STANDARD.encode(b"wrong-header-for-testing-purposes");
        std::fs::write(&path, garbage)?;
        assert!(load_identity_key(&path, None).is_err());
        Ok(())
    }
    // A public key that does not match the private key is reported as an error.
    #[test]
    fn test_validate_identity_key_pair_mismatch() -> Result<()> {
        let key = load_identity_key(&PathBuf::from("tests/keys/id_x25519_test"), None)?;
        let wrong_pub = vec![0u8; 32];
        assert!(
            validate_identity_key_pair(key.key_algorithm(), &wrong_pub, key.private_key()).is_err()
        );
        Ok(())
    }
    // Unknown algorithm identifiers are rejected.
    #[test]
    fn test_validate_identity_key_pair_unsupported_alg() {
        assert!(validate_identity_key_pair("bogus-alg", &[0u8; 32], &[0u8; 32]).is_err());
    }
    // P-384 keys round-trip through write/load and validate.
    #[test]
    fn test_generate_key_pair_p384() -> Result<()> {
        let key_pair = super::KeyPair::generate_key_pair(
            None,
            super::KexMode::Server("0.0.0.0:0".parse().unwrap()),
            super::KEY_ALGORITHM_P384,
        )?;
        let dir = tempfile::TempDir::new()?;
        let path = dir.path().join("id_p384");
        let mut f = std::fs::File::create(&path)?;
        key_pair.write_private_key(&mut f)?;
        drop(f);
        let loaded = load_identity_key(&path, None)?;
        assert_eq!(loaded.key_algorithm(), super::KEY_ALGORITHM_P384);
        validate_identity_key_pair(
            super::KEY_ALGORITHM_P384,
            loaded.public_key(),
            loaded.private_key(),
        )?;
        Ok(())
    }
    // P-256 keys round-trip through write/load and validate.
    #[test]
    fn test_generate_key_pair_p256() -> Result<()> {
        let key_pair = super::KeyPair::generate_key_pair(
            None,
            super::KexMode::Server("0.0.0.0:0".parse().unwrap()),
            super::KEY_ALGORITHM_P256,
        )?;
        let dir = tempfile::TempDir::new()?;
        let path = dir.path().join("id_p256");
        let mut f = std::fs::File::create(&path)?;
        key_pair.write_private_key(&mut f)?;
        drop(f);
        let loaded = load_identity_key(&path, None)?;
        assert_eq!(loaded.key_algorithm(), super::KEY_ALGORITHM_P256);
        validate_identity_key_pair(
            super::KEY_ALGORITHM_P256,
            loaded.public_key(),
            loaded.private_key(),
        )?;
        Ok(())
    }
    // Client keys cannot be generated without a passphrase.
    #[test]
    fn test_generate_key_pair_client_requires_passphrase() {
        assert!(
            super::KeyPair::generate_key_pair(
                None,
                super::KexMode::Client,
                super::KEY_ALGORITHM_X25519,
            )
            .is_err()
        );
    }
    // Unknown algorithm names are rejected at generation time.
    #[test]
    fn test_generate_key_pair_unknown_algorithm() {
        assert!(
            super::KeyPair::generate_key_pair(
                None,
                super::KexMode::Server("0.0.0.0:0".parse().unwrap()),
                "unknown-alg",
            )
            .is_err()
        );
    }
    // A passphrase-protected X25519 key round-trips and decrypts with the same passphrase.
    #[test]
    fn test_generate_key_pair_encrypted_x25519() -> Result<()> {
        let passphrase = "my-test-passphrase".to_string();
        let key_pair = super::KeyPair::generate_key_pair(
            Some(&passphrase),
            super::KexMode::Server("0.0.0.0:0".parse().unwrap()),
            super::KEY_ALGORITHM_X25519,
        )?;
        let dir = tempfile::TempDir::new()?;
        let path = dir.path().join("id_x25519_enc");
        let mut f = std::fs::File::create(&path)?;
        key_pair.write_private_key(&mut f)?;
        drop(f);
        let loaded = load_identity_key(&path, Some(&passphrase))?;
        assert_eq!(loaded.key_algorithm(), super::KEY_ALGORITHM_X25519);
        assert_eq!(loaded.public_key().len(), 32);
        validate_identity_key_pair(
            super::KEY_ALGORITHM_X25519,
            loaded.public_key(),
            loaded.private_key(),
        )?;
        Ok(())
    }
    // A passphrase-protected ML-DSA-44 key round-trips (`unstable` feature only).
    #[cfg(feature = "unstable")]
    #[test]
    fn test_load_identity_key_enc_ml_dsa() -> Result<()> {
        let passphrase = "ml-dsa-passphrase".to_string();
        let key_pair = super::KeyPair::generate_key_pair(
            Some(&passphrase),
            super::KexMode::Server("0.0.0.0:0".parse().unwrap()),
            super::KEY_ALGORITHM_ML_DSA_44,
        )?;
        let dir = tempfile::TempDir::new()?;
        let path = dir.path().join("id_mldsa_enc");
        let mut f = std::fs::File::create(&path)?;
        key_pair.write_private_key(&mut f)?;
        drop(f);
        let loaded = load_identity_key(&path, Some(&passphrase))?;
        assert_eq!(loaded.key_algorithm(), super::KEY_ALGORITHM_ML_DSA_44);
        assert!(!loaded.public_key().is_empty());
        assert!(!loaded.private_key().is_empty());
        Ok(())
    }
}