#[cfg(not(feature = "std"))]
use alloc::vec::Vec;
use chacha20poly1305::{
aead::{Aead, KeyInit, OsRng},
ChaCha20Poly1305, Nonce,
};
use hkdf::Hkdf;
use rand_core::RngCore;
use sha2::Sha256;
/// Failure modes of the mesh encryption helpers.
///
/// Variants carry no payload, so an error value reveals nothing about the key
/// or the data involved.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum EncryptionError {
    /// The cipher could not be constructed or sealing the plaintext failed.
    EncryptionFailed,
    /// The cipher could not be constructed or the ciphertext did not
    /// authenticate (wrong key, tampered or corrupted data).
    DecryptionFailed,
    /// Encoded input was too short to hold a nonce and authentication tag.
    InvalidFormat,
}

impl core::fmt::Display for EncryptionError {
    /// Writes a short, lowercase, human-readable description of the error.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let message = match self {
            Self::EncryptionFailed => "encryption failed",
            Self::DecryptionFailed => "decryption failed (wrong key or corrupted data)",
            Self::InvalidFormat => "invalid encrypted document format",
        };
        f.write_str(message)
    }
}

// Only available when the standard library is enabled; `core` has no `Error` trait
// in the toolchains this crate may target.
#[cfg(feature = "std")]
impl std::error::Error for EncryptionError {}
/// An encrypted payload: the random nonce plus the AEAD output.
///
/// Wire format (see [`encode`](Self::encode)): `nonce (12 bytes) || ciphertext`,
/// where `ciphertext` already ends with the 16-byte Poly1305 authentication tag.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EncryptedDocument {
    /// 96-bit nonce used when sealing `ciphertext`.
    pub nonce: [u8; 12],
    /// Encrypted bytes followed by the authentication tag.
    pub ciphertext: Vec<u8>,
}

impl EncryptedDocument {
    /// Length in bytes of the nonce prefix on the wire.
    pub const NONCE_LEN: usize = 12;
    /// Length in bytes of the authentication tag at the end of `ciphertext`.
    pub const TAG_LEN: usize = 16;
    /// Bytes the wire format adds on top of the plaintext (nonce + tag).
    pub const OVERHEAD: usize = Self::NONCE_LEN + Self::TAG_LEN;

    /// Serializes the document as `nonce || ciphertext`.
    pub fn encode(&self) -> Vec<u8> {
        let mut buf = Vec::with_capacity(Self::NONCE_LEN + self.ciphertext.len());
        buf.extend_from_slice(&self.nonce);
        buf.extend_from_slice(&self.ciphertext);
        buf
    }

    /// Parses bytes produced by [`encode`](Self::encode).
    ///
    /// Returns `None` when `data` is shorter than [`OVERHEAD`](Self::OVERHEAD),
    /// i.e. it cannot contain both a nonce and an authentication tag. This only
    /// splits the bytes; integrity is checked later, when the AEAD decrypts.
    pub fn decode(data: &[u8]) -> Option<Self> {
        if data.len() < Self::OVERHEAD {
            return None;
        }
        let (nonce_bytes, ciphertext) = data.split_at(Self::NONCE_LEN);
        let mut nonce = [0u8; Self::NONCE_LEN];
        nonce.copy_from_slice(nonce_bytes);
        Some(Self {
            nonce,
            ciphertext: ciphertext.to_vec(),
        })
    }
}
/// Symmetric ChaCha20-Poly1305 key shared by every member of a mesh.
///
/// The key is derived with HKDF-SHA256 from a 32-byte shared secret, using the
/// mesh ID as the HKDF salt, so one secret yields distinct keys per mesh.
///
/// NOTE(review): there is no `Drop` impl, so key bytes are not zeroized when
/// the value is dropped, and `Clone` duplicates them.
#[derive(Clone)]
pub struct MeshEncryptionKey {
    key: [u8; 32],
}

impl MeshEncryptionKey {
    /// HKDF `info` label binding derived keys to this protocol and version.
    const HKDF_INFO: &'static [u8] = b"PEAT-BTLE-mesh-encryption-v1";

    /// Derives the mesh key from `secret`, salting HKDF-SHA256 with `mesh_id`.
    ///
    /// Deterministic: the same `(mesh_id, secret)` pair always gives the same key.
    pub fn from_shared_secret(mesh_id: &str, secret: &[u8; 32]) -> Self {
        let salt = mesh_id.as_bytes();
        let kdf = Hkdf::<Sha256>::new(Some(salt), secret);
        let mut derived = [0u8; 32];
        // 32 bytes is far below HKDF-SHA256's maximum output, so this cannot fail.
        kdf.expand(Self::HKDF_INFO, &mut derived)
            .expect("32 bytes is valid output length for HKDF-SHA256");
        Self { key: derived }
    }

    /// Encrypts `plaintext` under a freshly generated random 12-byte nonce.
    ///
    /// NOTE(review): with random 96-bit nonces, collision risk grows with the
    /// number of messages sealed under one key; keep per-key volume modest.
    ///
    /// # Errors
    /// [`EncryptionError::EncryptionFailed`] if the cipher cannot be built or
    /// sealing fails.
    pub fn encrypt(&self, plaintext: &[u8]) -> Result<EncryptedDocument, EncryptionError> {
        let aead = ChaCha20Poly1305::new_from_slice(&self.key)
            .map_err(|_| EncryptionError::EncryptionFailed)?;
        let mut nonce = [0u8; 12];
        OsRng.fill_bytes(&mut nonce);
        let sealed = aead
            .encrypt(Nonce::from_slice(&nonce), plaintext)
            .map_err(|_| EncryptionError::EncryptionFailed)?;
        Ok(EncryptedDocument {
            nonce,
            ciphertext: sealed,
        })
    }

    /// Authenticates and decrypts `encrypted`, returning the plaintext.
    ///
    /// # Errors
    /// [`EncryptionError::DecryptionFailed`] if the cipher cannot be built or the
    /// ciphertext fails authentication (wrong key, tampering, corruption).
    pub fn decrypt(&self, encrypted: &EncryptedDocument) -> Result<Vec<u8>, EncryptionError> {
        let nonce = Nonce::from_slice(&encrypted.nonce);
        ChaCha20Poly1305::new_from_slice(&self.key)
            .map_err(|_| EncryptionError::DecryptionFailed)?
            .decrypt(nonce, encrypted.ciphertext.as_slice())
            .map_err(|_| EncryptionError::DecryptionFailed)
    }

    /// Encrypts `plaintext` and returns it in the `nonce || ciphertext` wire format.
    ///
    /// # Errors
    /// Same as [`encrypt`](Self::encrypt).
    pub fn encrypt_to_bytes(&self, plaintext: &[u8]) -> Result<Vec<u8>, EncryptionError> {
        self.encrypt(plaintext).map(|doc| doc.encode())
    }

    /// Parses wire-format `data` and decrypts it.
    ///
    /// # Errors
    /// [`EncryptionError::InvalidFormat`] if `data` is too short to parse;
    /// otherwise the errors of [`decrypt`](Self::decrypt).
    pub fn decrypt_from_bytes(&self, data: &[u8]) -> Result<Vec<u8>, EncryptionError> {
        match EncryptedDocument::decode(data) {
            Some(doc) => self.decrypt(&doc),
            None => Err(EncryptionError::InvalidFormat),
        }
    }
}

impl core::fmt::Debug for MeshEncryptionKey {
    /// Formats the key without revealing any of its bytes.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let mut out = f.debug_struct("MeshEncryptionKey");
        out.field("key", &"[REDACTED]");
        out.finish()
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for key derivation, AEAD round-trips, tamper detection, the
    //! wire format, and `Debug` redaction.
    use super::*;

    #[test]
    fn test_key_derivation_deterministic() {
        let secret = [0x42u8; 32];
        let key1 = MeshEncryptionKey::from_shared_secret("DEMO", &secret);
        let key2 = MeshEncryptionKey::from_shared_secret("DEMO", &secret);
        assert_eq!(key1.key, key2.key);
    }

    #[test]
    fn test_key_derivation_different_mesh_id() {
        let secret = [0x42u8; 32];
        let key1 = MeshEncryptionKey::from_shared_secret("DEMO", &secret);
        let key2 = MeshEncryptionKey::from_shared_secret("ALPHA", &secret);
        assert_ne!(key1.key, key2.key);
    }

    #[test]
    fn test_key_derivation_different_secret() {
        let secret1 = [0x42u8; 32];
        let secret2 = [0x43u8; 32];
        let key1 = MeshEncryptionKey::from_shared_secret("DEMO", &secret1);
        let key2 = MeshEncryptionKey::from_shared_secret("DEMO", &secret2);
        assert_ne!(key1.key, key2.key);
    }

    #[test]
    fn test_encrypt_decrypt_roundtrip() {
        let secret = [0x42u8; 32];
        let key = MeshEncryptionKey::from_shared_secret("DEMO", &secret);
        let plaintext = b"Hello, Peat mesh!";
        let encrypted = key.encrypt(plaintext).unwrap();
        let decrypted = key.decrypt(&encrypted).unwrap();
        assert_eq!(plaintext.as_slice(), decrypted.as_slice());
    }

    #[test]
    fn test_encrypt_decrypt_empty() {
        let secret = [0x42u8; 32];
        let key = MeshEncryptionKey::from_shared_secret("DEMO", &secret);
        let plaintext = b"";
        let encrypted = key.encrypt(plaintext).unwrap();
        let decrypted = key.decrypt(&encrypted).unwrap();
        assert_eq!(plaintext.as_slice(), decrypted.as_slice());
    }

    #[test]
    fn test_encrypt_produces_different_ciphertext() {
        let secret = [0x42u8; 32];
        let key = MeshEncryptionKey::from_shared_secret("DEMO", &secret);
        let plaintext = b"Same message";
        let encrypted1 = key.encrypt(plaintext).unwrap();
        let encrypted2 = key.encrypt(plaintext).unwrap();
        assert_ne!(encrypted1.nonce, encrypted2.nonce);
        assert_ne!(encrypted1.ciphertext, encrypted2.ciphertext);
        assert_eq!(key.decrypt(&encrypted1).unwrap(), plaintext.as_slice());
        assert_eq!(key.decrypt(&encrypted2).unwrap(), plaintext.as_slice());
    }

    #[test]
    fn test_wrong_key_fails() {
        let secret1 = [0x42u8; 32];
        let secret2 = [0x43u8; 32];
        let key1 = MeshEncryptionKey::from_shared_secret("DEMO", &secret1);
        let key2 = MeshEncryptionKey::from_shared_secret("DEMO", &secret2);
        let plaintext = b"Secret message";
        let encrypted = key1.encrypt(plaintext).unwrap();
        let result = key2.decrypt(&encrypted);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), EncryptionError::DecryptionFailed);
    }

    #[test]
    fn test_wrong_mesh_id_fails() {
        // Same secret, different mesh: the derived keys differ, so decryption
        // must fail authentication.
        let secret = [0x42u8; 32];
        let demo = MeshEncryptionKey::from_shared_secret("DEMO", &secret);
        let alpha = MeshEncryptionKey::from_shared_secret("ALPHA", &secret);
        let encrypted = demo.encrypt(b"Mesh-scoped message").unwrap();
        assert_eq!(
            alpha.decrypt(&encrypted),
            Err(EncryptionError::DecryptionFailed)
        );
    }

    #[test]
    fn test_tampered_ciphertext_fails() {
        let secret = [0x42u8; 32];
        let key = MeshEncryptionKey::from_shared_secret("DEMO", &secret);
        let plaintext = b"Authentic message";
        let mut encrypted = key.encrypt(plaintext).unwrap();
        // The ciphertext always carries a 16-byte tag, so index 0 exists.
        encrypted.ciphertext[0] ^= 0xFF;
        assert_eq!(key.decrypt(&encrypted), Err(EncryptionError::DecryptionFailed));
    }

    #[test]
    fn test_tampered_tag_fails() {
        let secret = [0x42u8; 32];
        let key = MeshEncryptionKey::from_shared_secret("DEMO", &secret);
        let mut encrypted = key.encrypt(b"Authentic message").unwrap();
        let last = encrypted.ciphertext.len() - 1;
        encrypted.ciphertext[last] ^= 0x01;
        assert_eq!(key.decrypt(&encrypted), Err(EncryptionError::DecryptionFailed));
    }

    #[test]
    fn test_tampered_nonce_fails() {
        let secret = [0x42u8; 32];
        let key = MeshEncryptionKey::from_shared_secret("DEMO", &secret);
        let mut encrypted = key.encrypt(b"Authentic message").unwrap();
        encrypted.nonce[0] ^= 0x01;
        assert_eq!(key.decrypt(&encrypted), Err(EncryptionError::DecryptionFailed));
    }

    #[test]
    fn test_encrypted_document_encode_decode() {
        let secret = [0x42u8; 32];
        let key = MeshEncryptionKey::from_shared_secret("DEMO", &secret);
        let plaintext = b"Wire format test";
        let encrypted = key.encrypt(plaintext).unwrap();
        let wire_bytes = encrypted.encode();
        let decoded = EncryptedDocument::decode(&wire_bytes).unwrap();
        assert_eq!(encrypted.nonce, decoded.nonce);
        assert_eq!(encrypted.ciphertext, decoded.ciphertext);
        let decrypted = key.decrypt(&decoded).unwrap();
        assert_eq!(plaintext.as_slice(), decrypted.as_slice());
    }

    #[test]
    fn test_convenience_methods() {
        let secret = [0x42u8; 32];
        let key = MeshEncryptionKey::from_shared_secret("DEMO", &secret);
        let plaintext = b"Convenience test";
        let wire_bytes = key.encrypt_to_bytes(plaintext).unwrap();
        let decrypted = key.decrypt_from_bytes(&wire_bytes).unwrap();
        assert_eq!(plaintext.as_slice(), decrypted.as_slice());
    }

    #[test]
    fn test_decrypt_from_bytes_rejects_bad_input() {
        let secret = [0x42u8; 32];
        let key = MeshEncryptionKey::from_shared_secret("DEMO", &secret);
        // Too short to hold nonce + tag: rejected before any cryptography runs.
        assert_eq!(
            key.decrypt_from_bytes(&[0u8; 27]),
            Err(EncryptionError::InvalidFormat)
        );
        assert_eq!(key.decrypt_from_bytes(&[]), Err(EncryptionError::InvalidFormat));
        // Long enough to parse, but the all-zero "tag" cannot authenticate.
        assert_eq!(
            key.decrypt_from_bytes(&[0u8; 28]),
            Err(EncryptionError::DecryptionFailed)
        );
    }

    #[test]
    fn test_encrypted_document_decode_too_short() {
        let short_data = [0u8; 27];
        assert!(EncryptedDocument::decode(&short_data).is_none());
        let minimal_data = [0u8; 28];
        assert!(EncryptedDocument::decode(&minimal_data).is_some());
    }

    #[test]
    fn test_overhead_calculation() {
        let secret = [0x42u8; 32];
        let key = MeshEncryptionKey::from_shared_secret("DEMO", &secret);
        let plaintext = b"Testing overhead";
        let encrypted = key.encrypt(plaintext).unwrap();
        let wire_bytes = encrypted.encode();
        let expected_size = 12 + plaintext.len() + 16;
        assert_eq!(wire_bytes.len(), expected_size);
        assert_eq!(
            wire_bytes.len() - plaintext.len(),
            EncryptedDocument::OVERHEAD
        );
    }

    #[test]
    fn test_debug_redacts_key() {
        let secret = [0x42u8; 32];
        let key = MeshEncryptionKey::from_shared_secret("DEMO", &secret);
        let debug_str = format!("{:?}", key);
        assert_eq!(debug_str, "MeshEncryptionKey { key: \"[REDACTED]\" }");
        // `Debug` renders byte arrays in decimal, so look for the actual derived
        // key bytes and the secret rather than the hex literal `42`.
        assert!(!debug_str.contains(&format!("{:?}", key.key)));
        assert!(!debug_str.contains(&format!("{:?}", secret)));
    }

    #[test]
    fn test_realistic_document_size() {
        let secret = [0x42u8; 32];
        let key = MeshEncryptionKey::from_shared_secret("DEMO", &secret);
        let doc = vec![0xABu8; 100];
        let encrypted = key.encrypt(&doc).unwrap();
        let wire_bytes = encrypted.encode();
        assert_eq!(wire_bytes.len(), 128);
        assert_eq!(wire_bytes.len(), doc.len() + EncryptedDocument::OVERHEAD);
        // NOTE(review): 244 and 512 presumably mirror BLE payload limits
        // (max ATT payload with data length extension / max attribute length).
        assert!(wire_bytes.len() < 244);
        assert!(wire_bytes.len() < 512);
    }
}