use base64::Engine;
use hpke::{
Deserializable, Kem, OpModeR, Serializable, aead::ChaCha20Poly1305, kdf::HkdfSha256,
kem::DhP256HkdfSha256,
};
use p256::{PublicKey, elliptic_curve::SecretKey, pkcs8::DecodePrivateKey};
use spki::EncodePublicKey;
use crate::KeyError;
/// HPKE receiver key pair for the DHKEM(P-256, HKDF-SHA256) KEM.
///
/// The public half is exported as base64 SPKI DER via `public_key()`. The
/// private half is used by `decrypt_raw` / `decrypt_p256`, which take `self`
/// by value, so one instance is consumed by a single decryption.
pub struct PrivyHpke {
    /// Receiver private key passed to `hpke::setup_receiver`.
    private_key: <DhP256HkdfSha256 as Kem>::PrivateKey,
    /// Receiver public key; senders encapsulate to this key.
    public_key: <DhP256HkdfSha256 as Kem>::PublicKey,
}
impl PrivyHpke {
    /// Generates a fresh receiver key pair using the thread-local RNG.
    #[must_use]
    pub fn new() -> Self {
        let mut rng = rand::thread_rng();
        let (private_key, public_key) = DhP256HkdfSha256::gen_keypair(&mut rng);
        Self {
            private_key,
            public_key,
        }
    }

    /// Deterministically derives a key pair from `seed` (test builds only), so
    /// tests can pin the expected public key.
    #[cfg(test)]
    pub(crate) fn new_with_seed(seed: u64) -> Self {
        use hpke::rand_core::SeedableRng;
        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(seed);
        let (private_key, public_key) = DhP256HkdfSha256::gen_keypair(&mut rng);
        Self {
            private_key,
            public_key,
        }
    }

    /// Returns the receiver public key as standard-alphabet base64 of its
    /// DER-encoded `SubjectPublicKeyInfo`.
    ///
    /// # Errors
    ///
    /// Returns [`KeyError::InvalidFormat`] if the KEM public key bytes are not a
    /// valid SEC1 P-256 point, or if SPKI DER encoding fails.
    pub fn public_key(&self) -> Result<String, KeyError> {
        let public_key_bytes = self.public_key.to_bytes();
        let p256_pk = PublicKey::from_sec1_bytes(&public_key_bytes)
            .map_err(|_| KeyError::InvalidFormat("invalid SEC1 public key point".to_string()))?;
        let spki_doc = p256_pk
            .to_public_key_der()
            .map_err(|_| KeyError::InvalidFormat("SPKI DER encoding failed".to_string()))?;
        Ok(base64::engine::general_purpose::STANDARD.encode(spki_doc.as_bytes()))
    }

    /// Decrypts an HPKE-sealed P-256 private key.
    ///
    /// The decrypted plaintext must be the standard-alphabet base64 encoding of
    /// a PKCS#8 DER private key. Every intermediate copy of the secret (the
    /// decrypted bytes and the decoded DER) is held in a buffer that is
    /// zeroized on drop, including on error paths.
    ///
    /// # Errors
    ///
    /// - Any error returned by [`Self::decrypt_raw`].
    /// - [`KeyError::InvalidFormat`] if the plaintext is not UTF-8, is not valid
    ///   base64, or does not decode to a PKCS#8 P-256 private key.
    pub fn decrypt_p256(
        self,
        encapsulated_key: &str,
        ciphertext: &str,
    ) -> Result<SecretKey<p256::NistP256>, KeyError> {
        let decrypted_key_bytes = self.decrypt_raw(encapsulated_key, ciphertext)?;
        // Borrow the zeroizing buffer instead of copying it: `String::from_utf8`
        // on a `.to_vec()` copy left an unzeroized duplicate of the secret
        // inside the `FromUtf8Error` on the failure path.
        let key_b64 = std::str::from_utf8(&decrypted_key_bytes).map_err(|_| {
            KeyError::InvalidFormat("decrypted key is not valid UTF-8".to_string())
        })?;
        #[cfg(all(feature = "unsafe_debug", debug_assertions))]
        {
            let truncated_key: String = key_b64.chars().take(6).collect();
            tracing::debug!(
                "Decrypted authorization key (base64 DER): {}",
                truncated_key
            );
        }
        // Decode straight into a zeroizing buffer so that bytes decoded before
        // a mid-stream base64 error are wiped too, not just the success result.
        let mut der_bytes = zeroize::Zeroizing::new(Vec::<u8>::new());
        base64::engine::general_purpose::STANDARD
            .decode_vec(key_b64, &mut *der_bytes)
            .map_err(|_| {
                KeyError::InvalidFormat("decrypted key is not valid base64".to_string())
            })?;
        SecretKey::<p256::NistP256>::from_pkcs8_der(&der_bytes).map_err(|e| {
            tracing::error!("Failed to parse decrypted PKCS#8 DER key: {:?}", e);
            KeyError::InvalidFormat("decrypted PKCS#8 DER key".to_string())
        })
    }

    /// Opens an HPKE ciphertext addressed to this key pair and returns the raw
    /// plaintext in a buffer that is zeroized on drop.
    ///
    /// Uses base mode with an empty `info` and empty AAD, and the suite
    /// DHKEM(P-256, HKDF-SHA256) / HKDF-SHA256 / ChaCha20-Poly1305. Both
    /// arguments are standard-alphabet base64.
    ///
    /// # Errors
    ///
    /// - [`KeyError::InvalidFormat`] if either argument is not valid base64, or
    ///   if the encapsulated key does not deserialize as a P-256 encapsulation.
    /// - Errors from `hpke::setup_receiver` or the AEAD `open` (for example, an
    ///   authentication failure), converted into `KeyError` via `?`.
    pub fn decrypt_raw(
        self,
        encapsulated_key: &str,
        ciphertext: &str,
    ) -> Result<zeroize::Zeroizing<Vec<u8>>, KeyError> {
        let encapped_key_bytes = base64::engine::general_purpose::STANDARD
            .decode(encapsulated_key)
            .map_err(|_| KeyError::InvalidFormat("base64 encapsulated key".to_string()))?;
        let ciphertext_bytes = base64::engine::general_purpose::STANDARD
            .decode(ciphertext)
            .map_err(|_| KeyError::InvalidFormat("base64 ciphertext".to_string()))?;
        tracing::debug!(
            "Deserializing encapsulated key len {}",
            encapped_key_bytes.len()
        );
        let encapped_key = <DhP256HkdfSha256 as Kem>::EncappedKey::from_bytes(&encapped_key_bytes)
            .map_err(|e| {
                tracing::error!("Failed to deserialize encapsulated key: {e:?}");
                KeyError::InvalidFormat("encapsulated key".to_string())
            })?;
        let mut context = hpke::setup_receiver::<ChaCha20Poly1305, HkdfSha256, DhP256HkdfSha256>(
            &OpModeR::Base,
            &self.private_key,
            &encapped_key,
            &[],
        )?;
        // Wrap the plaintext immediately so it is wiped when the caller drops it.
        Ok(zeroize::Zeroizing::new(
            context.open(&ciphertext_bytes, &[])?,
        ))
    }
}
impl Default for PrivyHpke {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use spki::DecodePublicKey;
    use test_case::test_case;

    use super::*;

    /// Seals `plaintext` to `receiver`'s public key with the same HPKE suite the
    /// receiver opens with. Returns base64 `(encapsulated_key, ciphertext)`.
    fn seal_for(receiver: &PrivyHpke, plaintext: &[u8]) -> (String, String) {
        use hpke::{OpModeS, Serializable};
        let mut rng = rand::thread_rng();
        let (encapped_key, mut sender_ctx) =
            hpke::setup_sender::<ChaCha20Poly1305, HkdfSha256, DhP256HkdfSha256, _>(
                &OpModeS::Base,
                &receiver.public_key,
                &[],
                &mut rng,
            )
            .expect("Failed to setup sender");
        let ciphertext = sender_ctx.seal(plaintext, &[]).expect("Failed to encrypt");
        (
            base64::engine::general_purpose::STANDARD.encode(encapped_key.to_bytes()),
            base64::engine::general_purpose::STANDARD.encode(&ciphertext),
        )
    }

    /// Asserts that `result` is `KeyError::InvalidFormat` with a message containing `needle`.
    fn assert_invalid_format<T>(result: Result<T, KeyError>, needle: &str) {
        match result {
            Err(KeyError::InvalidFormat(msg)) => assert!(
                msg.contains(needle),
                "Error message {msg:?} should mention {needle:?}"
            ),
            Err(other) => panic!("Expected InvalidFormat error, got: {other:?}"),
            Ok(_) => panic!("Expected error but got success"),
        }
    }

    #[test_case(0, "MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAECT+o7IjvJ+4MjHTU51k5HLoXT9WKzjJKbqkGA3bcvx+ESEbM/wtxRDsptOMcsP+Vn60KdYOjIyLAU/P96CB2lA==" ; "zero")]
    #[test_case(1, "MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAE5C1LvDxhkHINqB7lRM47O+sUIKTs/2YiPoNOQaRH2tnkhUjRC1x+g9yo0UZr/HzdJKNMAkSXRovCzovSr0jL3A==" ; "one")]
    #[test_case(10, "MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAECh6n0GOhDIloBuKZWx2/tPG3rX6oNuQdzH666gAYINFrZcC+GB/zICKGq+f7iXeobumsQiz38X8KKmOQoYkryA==" ; "ten")]
    fn test_generate_key(seed: u64, expected: &str) {
        let hpke = super::PrivyHpke::new_with_seed(seed);
        let public_key = hpke.public_key().unwrap();
        assert_eq!(public_key, expected);
    }

    #[test]
    fn test_spki_round_trip() {
        let hpke = PrivyHpke::new_with_seed(42);
        let spki_b64 = hpke.public_key().unwrap();
        let spki_der = base64::engine::general_purpose::STANDARD
            .decode(&spki_b64)
            .unwrap();
        let decoded_pk = PublicKey::from_public_key_der(&spki_der).unwrap();
        let original_pk_bytes = hpke.public_key.to_bytes();
        let original_pk = PublicKey::from_sec1_bytes(&original_pk_bytes).unwrap();
        assert_eq!(original_pk, decoded_pk);
    }

    #[test]
    fn test_spki_structure_correctness() {
        let hpke = PrivyHpke::new_with_seed(1337);
        let spki_b64 = hpke.public_key().unwrap();
        let spki_der = base64::engine::general_purpose::STANDARD
            .decode(&spki_b64)
            .unwrap();
        assert_eq!(spki_der.len(), 91, "SPKI should be exactly 91 bytes");
        assert_eq!(spki_der[0], 0x30, "Should start with SEQUENCE tag");
        assert_eq!(
            spki_der[1], 0x59,
            "SEQUENCE should have 89 bytes of content"
        );
        let parsed_key = PublicKey::from_public_key_der(&spki_der).unwrap();
        let original_pk_bytes = hpke.public_key.to_bytes();
        let original_pk = PublicKey::from_sec1_bytes(&original_pk_bytes).unwrap();
        assert_eq!(original_pk, parsed_key);
    }

    #[test]
    fn test_spki_base64_format() {
        let hpke = PrivyHpke::new_with_seed(999);
        let spki_b64 = hpke.public_key().unwrap();
        let decoded = base64::engine::general_purpose::STANDARD
            .decode(&spki_b64)
            .expect("Should be valid base64");
        assert_eq!(decoded.len(), 91);
        let re_encoded = base64::engine::general_purpose::STANDARD.encode(&decoded);
        assert_eq!(spki_b64, re_encoded);
    }

    #[test]
    fn test_error_handling_invalid_key() {
        // 32 bytes is neither a compressed (33) nor an uncompressed (65) P-256 SEC1 encoding.
        let invalid_bytes = [0x04u8; 32];
        let result = PublicKey::from_sec1_bytes(&invalid_bytes);
        assert!(result.is_err(), "Should reject invalid SEC1 bytes");
        // A 0x02 tag announces a 33-byte compressed point, so 65 bytes is malformed.
        let invalid_bytes = [0x02u8; 65];
        let result = PublicKey::from_sec1_bytes(&invalid_bytes);
        assert!(result.is_err(), "Should reject invalid format indicator");
    }

    #[test]
    fn test_hpke_decrypt_success() {
        use p256::pkcs8::EncodePrivateKey;
        let receiver = PrivyHpke::new_with_seed(42);
        let mut rng = rand::thread_rng();
        let test_key = SecretKey::<p256::NistP256>::random(&mut rng);
        let test_key_pkcs8_der = test_key
            .to_pkcs8_der()
            .expect("Failed to encode test key as PKCS#8 DER");
        let test_key_b64 =
            base64::engine::general_purpose::STANDARD.encode(test_key_pkcs8_der.as_bytes());
        let (encapped_key_b64, ciphertext_b64) = seal_for(&receiver, test_key_b64.as_bytes());
        let decrypted_raw = receiver
            .decrypt_raw(&encapped_key_b64, &ciphertext_b64)
            .expect("Failed to decrypt raw");
        assert_eq!(
            &*decrypted_raw,
            test_key_b64.as_bytes(),
            "Decrypted raw bytes should match original"
        );
        // `decrypt_raw` consumed `receiver`; the same seed yields the same key pair.
        let receiver2 = PrivyHpke::new_with_seed(42);
        let decrypted_key = receiver2
            .decrypt_p256(&encapped_key_b64, &ciphertext_b64)
            .expect("Failed to decrypt as P256 key");
        assert_eq!(
            decrypted_key.to_bytes(),
            test_key.to_bytes(),
            "Decrypted key should match original key"
        );
    }

    #[test]
    fn test_hpke_decrypt_invalid_ciphertext() {
        let receiver = PrivyHpke::new_with_seed(100);
        // Use a genuine encapsulated key, but pair it with a ciphertext that fails authentication.
        let (encapped_key_b64, _) = seal_for(&receiver, b"unused");
        let invalid_ciphertext_b64 =
            base64::engine::general_purpose::STANDARD.encode([0x99u8; 64]);
        match receiver.decrypt_raw(&encapped_key_b64, &invalid_ciphertext_b64) {
            Err(KeyError::HpkeDecryption(_)) => {}
            Err(other) => panic!("Expected HpkeDecryption error, got: {other:?}"),
            Ok(_) => panic!("Expected error but got success"),
        }
    }

    #[test]
    fn test_hpke_decrypt_invalid_encapsulated_key() {
        let receiver = PrivyHpke::new_with_seed(200);
        // Valid base64, but 32 bytes cannot be a P-256 encapsulated key.
        let invalid_encapped_key_b64 =
            base64::engine::general_purpose::STANDARD.encode([0x04u8; 32]);
        let ciphertext_b64 = base64::engine::general_purpose::STANDARD.encode([0x00u8; 64]);
        let result = receiver.decrypt_raw(&invalid_encapped_key_b64, &ciphertext_b64);
        assert_invalid_format(result, "encapsulated key");
    }

    #[test]
    fn test_decrypt_p256_rejects_non_utf8_plaintext() {
        let receiver = PrivyHpke::new_with_seed(7);
        let (encapped_key_b64, ciphertext_b64) = seal_for(&receiver, &[0xff, 0xfe, 0xfd]);
        let result = receiver.decrypt_p256(&encapped_key_b64, &ciphertext_b64);
        assert_invalid_format(result, "UTF-8");
    }

    #[test]
    fn test_decrypt_p256_rejects_non_base64_plaintext() {
        let receiver = PrivyHpke::new_with_seed(8);
        let (encapped_key_b64, ciphertext_b64) = seal_for(&receiver, b"not base64!!");
        let result = receiver.decrypt_p256(&encapped_key_b64, &ciphertext_b64);
        assert_invalid_format(result, "base64");
    }

    #[test]
    fn test_decrypt_p256_rejects_non_pkcs8_plaintext() {
        let receiver = PrivyHpke::new_with_seed(9);
        let junk_b64 = base64::engine::general_purpose::STANDARD.encode(b"definitely not a key");
        let (encapped_key_b64, ciphertext_b64) = seal_for(&receiver, junk_b64.as_bytes());
        let result = receiver.decrypt_p256(&encapped_key_b64, &ciphertext_b64);
        assert_invalid_format(result, "PKCS#8");
    }
}