use std::io::{Read, Write};
use hpke::aead::ChaCha20Poly1305 as ChaCha20Poly1305_;
use hpke::kdf::HkdfSha256;
use hpke::kem::{Kem, X25519HkdfSha256};
use hpke::{Deserializable, OpModeR, OpModeS, Serializable};
use rand::{SeedableRng, rngs::StdRng};
use secrecy::{SecretSlice, zeroize::Zeroizing};
use crate::cipher::ChaCha20Poly1305;
use crate::traits::{self, Cipher, Error, GeneratedKey};
/// Magic prefix written at the very start of every stream produced by
/// `Hpke::encrypt_stream`; `Hpke::decrypt_stream` rejects input that does not
/// begin with exactly these bytes.
// NOTE(review): the trailing `\x01` looks like a format version — confirm.
const HEADER: &[u8; 5] = b"HPKE\x01";
/// Passed as the `info` argument to both `hpke::setup_sender` and
/// `hpke::setup_receiver`, so sender and receiver must agree on it.
const INFO: &[u8] = b"jolokia-hpke-stream-v1";
/// Label handed to the HPKE context's `export` call when deriving the 32-byte
/// key that is then given to the `ChaCha20Poly1305` stream cipher.
const EXPORT_LABEL: &[u8] = b"stream";
/// Stateless marker type for the HPKE-based cipher; all behaviour lives in its
/// `Cipher` implementation.
pub struct Hpke;
impl Cipher for Hpke {
fn generate_key(&self) -> GeneratedKey {
let mut csprng = StdRng::from_os_rng();
let (sk, pk) = <X25519HkdfSha256 as Kem>::gen_keypair(&mut csprng);
GeneratedKey::Asymmetric {
public: SecretSlice::from(pk.to_bytes().to_vec()),
private: SecretSlice::from(sk.to_bytes().to_vec()),
}
}
fn encrypt_stream(
&self,
public_key: &[u8],
reader: &mut dyn Read,
writer: &mut dyn Write,
) -> traits::Result<()> {
let public_key = <X25519HkdfSha256 as Kem>::PublicKey::from_bytes(public_key)
.map_err(|_| Error::Encrypt)?;
let mut csprng = StdRng::from_os_rng();
let (encapsulated_public_key, encryption_context) =
hpke::setup_sender::<ChaCha20Poly1305_, HkdfSha256, X25519HkdfSha256, _>(
&OpModeS::Base,
&public_key,
INFO,
&mut csprng,
)
.map_err(|_| Error::Encrypt)?;
let mut symmetric_key = Zeroizing::new([0u8; 32]);
encryption_context
.export(EXPORT_LABEL, symmetric_key.as_mut_slice())
.map_err(|_| Error::Encrypt)?;
writer
.write_all(HEADER)
.map_err(|e| Error::Write(e.to_string()))?;
let encapsulated_public_key = encapsulated_public_key.to_bytes();
let encapsulated_public_key_len = u16::try_from(encapsulated_public_key.len())
.map_err(|_| Error::Encrypt)?
.to_be_bytes();
writer
.write_all(&encapsulated_public_key_len)
.map_err(|e| Error::Write(e.to_string()))?;
writer
.write_all(&encapsulated_public_key)
.map_err(|e| Error::Write(e.to_string()))?;
ChaCha20Poly1305.encrypt_stream(symmetric_key.as_ref(), reader, writer)?;
Ok(())
}
fn decrypt_stream(
&self,
private_key: &[u8],
reader: &mut dyn Read,
writer: &mut dyn Write,
) -> traits::Result<()> {
if usize::BITS < u16::BITS {
return Err(Error::Platform(
"< 16-bit platforms are not supported.".to_string(),
));
}
let private_key = <X25519HkdfSha256 as Kem>::PrivateKey::from_bytes(private_key)
.map_err(|_| Error::Decrypt)?;
let mut header = [0u8; HEADER.len()];
reader
.read_exact(&mut header)
.map_err(|e| Error::Read(e.to_string()))?;
if &header != HEADER {
return Err(Error::Algorithm);
}
let mut encapsulated_public_key_len = [0u8; 2];
reader
.read_exact(&mut encapsulated_public_key_len)
.map_err(|e| Error::Read(e.to_string()))?;
let encapsulated_public_key_len = u16::from_be_bytes(encapsulated_public_key_len) as usize;
if encapsulated_public_key_len != 32 {
return Err(Error::Decrypt);
}
let mut encapsulated_public_key = vec![0u8; encapsulated_public_key_len];
reader
.read_exact(&mut encapsulated_public_key)
.map_err(|e| Error::Read(e.to_string()))?;
let encapsulated_public_key =
<X25519HkdfSha256 as Kem>::EncappedKey::from_bytes(&encapsulated_public_key)
.map_err(|_| Error::Decrypt)?;
let decryption_context = hpke::setup_receiver::<
ChaCha20Poly1305_,
HkdfSha256,
X25519HkdfSha256,
>(
&OpModeR::Base, &private_key, &encapsulated_public_key, INFO
)
.map_err(|_| Error::Decrypt)?;
let mut symmetric_key = Zeroizing::new([0u8; 32]);
decryption_context
.export(EXPORT_LABEL, symmetric_key.as_mut_slice())
.map_err(|_| Error::Decrypt)?;
ChaCha20Poly1305.decrypt_stream(symmetric_key.as_ref(), reader, writer)?;
Ok(())
}
}
#[cfg(test)]
pub mod tests {
    use std::io::Cursor;

    use super::*;
    use crate::traits::Base64Decode;

    /// Base64-encoded public key used by every test in this module.
    const PUBLIC_KEY_B64: &str = "lNLRjAfH2i8QfgEBmkwb9DyigB6mFae94FYCx46qij0";
    /// Base64-encoded private key paired with `PUBLIC_KEY_B64`.
    const PRIVATE_KEY_B64: &str = "caEdcM9zySxJCc+HBD7QzzpJwBVWm2BcGyBMoGETi+g";
    /// Size the streaming tests are built around.
    // NOTE(review): presumably the chunk size of the underlying stream cipher — confirm.
    const CHUNK_LEN: usize = 4096;

    /// Encrypts `plaintext` with `Hpke::encrypt_stream`, decrypts the result with
    /// `Hpke::decrypt_stream`, and asserts the bytes survive unchanged.
    fn assert_stream_roundtrip(plaintext: &[u8]) {
        let public_key = PUBLIC_KEY_B64.base64_decode().unwrap();
        let private_key = PRIVATE_KEY_B64.base64_decode().unwrap();

        let mut encrypted: Vec<u8> = Vec::new();
        Hpke.encrypt_stream(&public_key, &mut Cursor::new(plaintext), &mut encrypted)
            .unwrap();
        assert!(
            encrypted.starts_with(HEADER),
            "ciphertext does not start with the HPKE header"
        );
        assert!(encrypted.len() > 8);

        let mut decrypted: Vec<u8> = Vec::new();
        Hpke.decrypt_stream(&private_key, &mut Cursor::new(encrypted), &mut decrypted)
            .unwrap();
        // Exact byte comparison: no lossy UTF-8 conversion in between.
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn hpke_encrypt_decrypt_roundtrip() {
        let public_key = PUBLIC_KEY_B64.base64_decode().unwrap();
        let private_key = PRIVATE_KEY_B64.base64_decode().unwrap();
        let plaintext = b"hello, world!";
        let encrypted = Hpke.encrypt(&public_key, plaintext).unwrap();
        let decrypted = Hpke.decrypt(&private_key, &encrypted).unwrap();
        let decrypted = String::from_utf8_lossy(&decrypted);
        assert_eq!(decrypted, "hello, world!");
    }

    #[test]
    fn hpke_encrypt_decrypt_streaming_roundtrip_shorter_than_a_chunk() {
        let plaintext = b"hello, world!";
        assert!(
            plaintext.len() < CHUNK_LEN,
            "{} >= {}",
            plaintext.len(),
            CHUNK_LEN
        );
        assert_stream_roundtrip(plaintext);
    }

    #[test]
    fn hpke_encrypt_decrypt_streaming_roundtrip_same_length_as_a_chunk() {
        // 13 bytes * 315 = 4095, plus one extra byte to land exactly on the boundary.
        let mut plaintext = b"hello, world!".repeat(315);
        plaintext.extend(b"1");
        assert_eq!(plaintext.len(), CHUNK_LEN);
        assert_stream_roundtrip(&plaintext);
    }

    #[test]
    fn hpke_encrypt_decrypt_streaming_roundtrip_longer_than_a_chunk() {
        let plaintext = b"hello, world!".repeat(320);
        assert!(
            plaintext.len() > CHUNK_LEN,
            "{} <= {}",
            plaintext.len(),
            CHUNK_LEN
        );
        assert_stream_roundtrip(&plaintext);
    }
}