use anyhow::{Context, Result, anyhow};
use chacha20poly1305::aead::{Aead, KeyInit, Payload};
use chacha20poly1305::{ChaCha20Poly1305, Key, Nonce};
use hkdf::Hkdf;
use rand_core::{OsRng, RngCore};
use sha2::Sha256;
use std::collections::HashMap;
use std::sync::Mutex;
use x25519_dalek::{PublicKey, SharedSecret, StaticSecret};
use zeroize::Zeroize;
pub const ENVELOPE_VERSION: u8 = 0x02;
pub const PUBKEY_LEN: usize = 32;
pub const NONCE_LEN: usize = 12;
pub const TAG_LEN: usize = 16;
pub const AEAD_KEY_LEN: usize = 32;
const HKDF_INFO: &[u8] = b"ai-memory/v0.7.0/e2e-content/chacha20poly1305-key/v2";
/// Long-term X25519 identity for one agent; the recipient key for
/// [`encrypt`] / [`decrypt`].
#[derive(Clone)]
pub struct Keypair {
    /// Agent this keypair was generated for.
    pub agent_id: String,
    /// X25519 public key (shareable).
    pub public: PublicKey,
    /// X25519 static secret; must never be printed or logged.
    pub secret: StaticSecret,
}

impl std::fmt::Debug for Keypair {
    /// Debug output that names the agent but hides all key material: the
    /// public key is replaced by a fixed marker and the secret by the crate's
    /// redaction placeholder.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut view = formatter.debug_struct("Keypair");
        view.field("agent_id", &self.agent_id);
        view.field("public", &"<x25519 pubkey>");
        view.field("secret", &crate::REDACTED_PLACEHOLDER);
        view.finish()
    }
}
/// Sealed payload produced by [`encrypt`] and consumed by [`decrypt`].
///
/// Wire layout (see [`Envelope::to_bytes`]):
/// `version (1) || ephemeral_pub (PUBKEY_LEN) || nonce (NONCE_LEN) || ciphertext`.
#[derive(Debug, Clone)]
pub struct Envelope {
    /// Sender's one-time X25519 public key.
    pub ephemeral_pub: [u8; PUBKEY_LEN],
    /// ChaCha20-Poly1305 nonce used for this envelope.
    pub nonce: [u8; NONCE_LEN],
    /// AEAD output, including the Poly1305 tag (at least `TAG_LEN` bytes when parsed).
    pub ciphertext: Vec<u8>,
}

impl Envelope {
    /// Serialise to the versioned wire format:
    /// `ENVELOPE_VERSION || ephemeral_pub || nonce || ciphertext`.
    #[must_use]
    pub fn to_bytes(&self) -> Vec<u8> {
        let total = 1 + PUBKEY_LEN + NONCE_LEN + self.ciphertext.len();
        let mut wire = Vec::with_capacity(total);
        wire.push(ENVELOPE_VERSION);
        for segment in [&self.ephemeral_pub[..], &self.nonce[..], &self.ciphertext[..]] {
            wire.extend_from_slice(segment);
        }
        wire
    }

    /// Parse the wire format written by [`Envelope::to_bytes`].
    ///
    /// # Errors
    ///
    /// Fails if `bytes` is shorter than the fixed header plus a Poly1305 tag,
    /// or if the leading version byte is not [`ENVELOPE_VERSION`]. The length
    /// is checked before the version.
    pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
        let header_len = 1 + PUBKEY_LEN + NONCE_LEN;
        let minimum = header_len + TAG_LEN;
        if bytes.len() < minimum {
            return Err(anyhow!(
                "envelope buffer too short: got {} bytes, need at least {}",
                bytes.len(),
                minimum
            ));
        }

        let version = bytes[0];
        if version != ENVELOPE_VERSION {
            return Err(anyhow!(
                "unknown envelope version: got 0x{:02x}, expected 0x{:02x}",
                version,
                ENVELOPE_VERSION
            ));
        }

        // The length guard above makes both splits in-bounds.
        let (pub_bytes, after_pub) = bytes[1..].split_at(PUBKEY_LEN);
        let (nonce_bytes, body) = after_pub.split_at(NONCE_LEN);

        let mut ephemeral_pub = [0u8; PUBKEY_LEN];
        ephemeral_pub.copy_from_slice(pub_bytes);
        let mut nonce = [0u8; NONCE_LEN];
        nonce.copy_from_slice(nonce_bytes);

        Ok(Self {
            ephemeral_pub,
            nonce,
            ciphertext: body.to_vec(),
        })
    }
}
/// Returns the agent-id → keypair cache stored on `RuntimeContext::global()`.
fn keypair_cache() -> &'static Mutex<HashMap<String, Keypair>> {
    let runtime = crate::runtime_context::RuntimeContext::global();
    &runtime.keypair_cache
}
/// Return the X25519 keypair cached for `agent_id`, generating one from the
/// OS RNG and caching it on first use.
///
/// The lookup and any insertion happen under a single lock acquisition, so
/// concurrent callers for the same agent always receive the same keypair.
/// This function only touches the in-memory cache; it performs no I/O.
/// NOTE(review): if nothing else persists these keys, they are lost on
/// restart — confirm.
///
/// # Errors
///
/// Returns an error if the cache mutex has been poisoned.
pub fn get_or_create_keypair(agent_id: &str) -> Result<Keypair> {
    let mut entries = keypair_cache()
        .lock()
        .map_err(|e| anyhow!("encryption keypair cache mutex poisoned: {e}"))?;

    let keypair = entries
        .entry(agent_id.to_owned())
        .or_insert_with(|| {
            let secret = StaticSecret::random_from_rng(OsRng);
            Keypair {
                agent_id: agent_id.to_owned(),
                public: PublicKey::from(&secret),
                secret,
            }
        })
        .clone();

    Ok(keypair)
}
/// Derive the ChaCha20-Poly1305 key from an X25519 shared secret.
///
/// Uses HKDF-SHA256 with no salt and the versioned `HKDF_INFO` label, so the
/// raw Diffie-Hellman output is never used directly as a cipher key.
/// Callers are responsible for zeroizing the returned array.
fn derive_aead_key(shared: &SharedSecret) -> [u8; AEAD_KEY_LEN] {
    let mut key = [0u8; AEAD_KEY_LEN];
    Hkdf::<Sha256>::new(None, shared.as_bytes())
        .expand(HKDF_INFO, &mut key)
        .expect("HKDF expand of AEAD_KEY_LEN bytes is within the 255*HashLen limit");
    key
}
/// Build the AEAD associated data: `ENVELOPE_VERSION || ephemeral_pub`.
///
/// Passing these header fields as AAD means the version byte and ephemeral
/// key are authenticated together with the ciphertext.
fn envelope_aad(ephemeral_pub: &[u8; PUBKEY_LEN]) -> [u8; 1 + PUBKEY_LEN] {
    // Seed every byte with the version, then overwrite all but the first
    // with the ephemeral public key.
    let mut aad = [ENVELOPE_VERSION; 1 + PUBKEY_LEN];
    let (_version_slot, key_slot) = aad.split_at_mut(1);
    key_slot.copy_from_slice(ephemeral_pub);
    aad
}
/// Encrypt `content` for the holder of `recipient_pk`.
///
/// Each call generates a fresh X25519 ephemeral key and performs
/// Diffie-Hellman with `recipient_pk`. The shared secret goes through
/// HKDF-SHA256 to produce a ChaCha20-Poly1305 key. A random 96-bit nonce is
/// drawn from the OS RNG. The version byte and ephemeral public key are
/// authenticated as associated data.
///
/// # Errors
///
/// Returns an error if `recipient_pk` is a low-order (non-contributory)
/// X25519 point. The shared secret would then be the fixed all-zero value,
/// the derived key would be public, and anyone could read the "encrypted"
/// content. Also returns an error if the AEAD encryption itself fails.
pub fn encrypt(content: &str, recipient_pk: &PublicKey) -> Result<Envelope> {
    let ephemeral_secret = StaticSecret::random_from_rng(OsRng);
    let ephemeral_public = PublicKey::from(&ephemeral_secret);
    let shared = ephemeral_secret.diffie_hellman(recipient_pk);

    // RFC 7748 §6.1: with a low-order peer key the DH output does not depend
    // on our secret. Encrypting under a key derived from it would provide no
    // confidentiality, so refuse instead of silently producing a readable envelope.
    if !shared.was_contributory() {
        return Err(anyhow!(
            "recipient public key is a low-order X25519 point; refusing to encrypt"
        ));
    }

    let mut okm = derive_aead_key(&shared);
    let cipher = ChaCha20Poly1305::new(Key::from_slice(&okm));
    // The cipher holds its own copy of the key; wipe the local one.
    okm.zeroize();

    let mut nonce_bytes = [0u8; NONCE_LEN];
    OsRng.fill_bytes(&mut nonce_bytes);
    let nonce = Nonce::from_slice(&nonce_bytes);

    let ephemeral_pub = ephemeral_public.to_bytes();
    let aad = envelope_aad(&ephemeral_pub);
    let ciphertext = cipher
        .encrypt(
            nonce,
            Payload {
                msg: content.as_bytes(),
                aad: &aad,
            },
        )
        .map_err(|e| anyhow!("ChaCha20-Poly1305 encrypt failed: {e}"))?;

    Ok(Envelope {
        ephemeral_pub,
        nonce: nonce_bytes,
        ciphertext,
    })
}
/// Decrypt an [`Envelope`] addressed to the holder of `my_sk`.
///
/// Recomputes the shared secret from `my_sk` and the envelope's ephemeral
/// public key, then derives the AEAD key via HKDF. It verifies the Poly1305
/// tag over the ciphertext and the `version || ephemeral_pub` associated data
/// before returning any plaintext.
///
/// # Errors
///
/// Returns an error in any of these cases:
/// - the envelope's ephemeral key yields a non-contributory (all-zero) shared
///   secret; ephemeral keys produced by [`encrypt`] never do this, so such an
///   envelope can only be a forgery under a publicly-known key;
/// - authentication fails (wrong secret, or a tampered ciphertext, nonce or
///   header);
/// - the plaintext is not valid UTF-8.
pub fn decrypt(envelope: &Envelope, my_sk: &StaticSecret) -> Result<String> {
    let ephemeral_public = PublicKey::from(envelope.ephemeral_pub);
    let shared = my_sk.diffie_hellman(&ephemeral_public);

    // Mirror the check in `encrypt` (RFC 7748 §6.1): a low-order ephemeral
    // point gives a key anyone can compute, so its "authenticated" output
    // would prove nothing.
    if !shared.was_contributory() {
        return Err(anyhow!(
            "envelope ephemeral public key is a low-order X25519 point; refusing to decrypt"
        ));
    }

    let mut okm = derive_aead_key(&shared);
    let cipher = ChaCha20Poly1305::new(Key::from_slice(&okm));
    // The cipher holds its own copy of the key; wipe the local one.
    okm.zeroize();

    let nonce = Nonce::from_slice(&envelope.nonce);
    let aad = envelope_aad(&envelope.ephemeral_pub);
    let plaintext = cipher
        .decrypt(
            nonce,
            Payload {
                msg: &envelope.ciphertext,
                aad: &aad,
            },
        )
        .map_err(|e| anyhow!("ChaCha20-Poly1305 decrypt failed (authentication): {e}"))?;

    String::from_utf8(plaintext).context("decrypted plaintext is not valid UTF-8")
}
/// Decide whether memory content should be encrypted at rest.
///
/// Returns `true` when `config_flag` is `Some(true)`. Otherwise the decision
/// comes from the `AI_MEMORY_ENCRYPT_AT_REST` environment variable, which
/// enables encryption for `1`, `true`, `yes` or `on`. Matching ignores ASCII
/// case and surrounding whitespace, so values like `"true\n"` from env files
/// do not silently fall back to plaintext.
///
/// `Some(false)` and `None` both defer to the environment variable. The
/// config flag can only turn encryption on, never force it off. An unset or
/// non-UTF-8 variable means "disabled".
#[must_use]
pub fn encryption_enabled(config_flag: Option<bool>) -> bool {
    if config_flag == Some(true) {
        return true;
    }
    std::env::var("AI_MEMORY_ENCRYPT_AT_REST").is_ok_and(|raw| {
        let value = raw.trim();
        ["1", "true", "yes", "on"]
            .iter()
            .any(|truthy| value.eq_ignore_ascii_case(truthy))
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    // Each test uses its own agent id because the keypair cache comes from
    // `RuntimeContext::global()` and is therefore shared by every test.

    // Repeated lookups for one agent must return the same key material.
    #[test]
    fn keypair_round_trip_returns_same_secret() {
        let agent = "test-agent-roundtrip";
        let a = get_or_create_keypair(agent).expect("first generate");
        let b = get_or_create_keypair(agent).expect("second fetch");
        assert_eq!(a.public.as_bytes(), b.public.as_bytes());
        assert_eq!(a.secret.to_bytes(), b.secret.to_bytes());
    }

    // Different agents must not share a keypair.
    #[test]
    fn keypair_distinct_for_distinct_agents() {
        let a = get_or_create_keypair("agent-a").expect("a");
        let b = get_or_create_keypair("agent-b").expect("b");
        assert_ne!(a.public.as_bytes(), b.public.as_bytes());
    }

    // Basic encrypt -> decrypt round trip, including non-ASCII UTF-8.
    #[test]
    fn encrypt_decrypt_round_trip_recovers_plaintext() {
        let kp = get_or_create_keypair("roundtrip-agent").expect("keypair");
        let plaintext = "hello world — encryption substrate MVP";
        let env = encrypt(plaintext, &kp.public).expect("encrypt");
        let recovered = decrypt(&env, &kp.secret).expect("decrypt");
        assert_eq!(recovered, plaintext);
    }

    // Serialising and re-parsing preserves every field, and the parsed
    // envelope still decrypts.
    #[test]
    fn envelope_wire_format_round_trips() {
        let kp = get_or_create_keypair("envelope-bytes").expect("kp");
        let env = encrypt("payload bytes", &kp.public).expect("encrypt");
        let bytes = env.to_bytes();
        let parsed = Envelope::from_bytes(&bytes).expect("parse");
        assert_eq!(env.ephemeral_pub, parsed.ephemeral_pub);
        assert_eq!(env.nonce, parsed.nonce);
        assert_eq!(env.ciphertext, parsed.ciphertext);
        let recovered = decrypt(&parsed, &kp.secret).expect("decrypt parsed");
        assert_eq!(recovered, "payload bytes");
    }

    // Buffers shorter than header + tag are rejected.
    #[test]
    fn envelope_parse_rejects_short_buffer() {
        assert!(Envelope::from_bytes(&[]).is_err());
        assert!(Envelope::from_bytes(&[0x01; 10]).is_err());
    }

    // A long-enough buffer with the wrong version byte is rejected.
    #[test]
    fn envelope_parse_rejects_unknown_version() {
        let mut bad = vec![0xFF];
        bad.extend_from_slice(&[0u8; PUBKEY_LEN + NONCE_LEN + TAG_LEN + 1]);
        assert!(Envelope::from_bytes(&bad).is_err());
    }

    // Only the intended recipient's secret can open the envelope.
    #[test]
    fn decrypt_with_wrong_secret_fails() {
        let kp_alice = get_or_create_keypair("alice-wrong-key").expect("alice");
        let kp_eve = get_or_create_keypair("eve-wrong-key").expect("eve");
        let env = encrypt("secret-for-alice", &kp_alice.public).expect("encrypt");
        assert!(decrypt(&env, &kp_eve.secret).is_err());
    }

    // A single flipped ciphertext bit must fail AEAD authentication.
    #[test]
    fn decrypt_with_tampered_ciphertext_fails() {
        let kp = get_or_create_keypair("tamper-detect").expect("kp");
        let mut env = encrypt("dont change this", &kp.public).expect("encrypt");
        env.ciphertext[0] ^= 0x01;
        assert!(decrypt(&env, &kp.secret).is_err());
    }

    // Both DH directions agree, HKDF is deterministic, and the output differs
    // from the raw shared secret.
    #[test]
    fn hkdf_derived_key_is_deterministic_and_differs_from_raw_shared_secret() {
        let alice = get_or_create_keypair("h3-hkdf-alice").expect("alice");
        let bob = get_or_create_keypair("h3-hkdf-bob").expect("bob");
        let shared_a = alice.secret.diffie_hellman(&bob.public);
        let shared_b = bob.secret.diffie_hellman(&alice.public);
        assert_eq!(shared_a.as_bytes(), shared_b.as_bytes());
        let key1 = derive_aead_key(&shared_a);
        let key2 = derive_aead_key(&shared_b);
        assert_eq!(key1, key2, "HKDF derivation must be deterministic");
        assert_eq!(key1.len(), AEAD_KEY_LEN);
        assert_ne!(
            &key1,
            shared_a.as_bytes(),
            "derived key must not be the raw shared secret (HKDF must transform it)"
        );
    }

    // The AAD layout is exactly `version || ephemeral_pub`.
    #[test]
    fn envelope_aad_binds_version_and_ephemeral_pub() {
        let pubkey = [7u8; PUBKEY_LEN];
        let aad = envelope_aad(&pubkey);
        assert_eq!(aad.len(), 1 + PUBKEY_LEN);
        assert_eq!(aad[0], ENVELOPE_VERSION, "AAD[0] must pin the version");
        assert_eq!(&aad[1..], &pubkey, "AAD tail must be the ephemeral pubkey");
    }

    // Substituting a different (valid) ephemeral public key must fail.
    #[test]
    fn decrypt_fails_when_ephemeral_pub_swapped() {
        let kp = get_or_create_keypair("h3-aad-swap").expect("kp");
        let mut env = encrypt("aad-bound payload", &kp.public).expect("encrypt");
        let other = get_or_create_keypair("h3-aad-swap-other").expect("other");
        env.ephemeral_pub = other.public.to_bytes();
        assert!(
            decrypt(&env, &kp.secret).is_err(),
            "a swapped ephemeral pubkey must fail AEAD authentication"
        );
    }

    // Pins the wire version byte so a format change is a deliberate decision.
    #[test]
    fn envelope_version_is_the_hkdf_aad_scheme() {
        let kp = get_or_create_keypair("h3-version-pin").expect("kp");
        let env = encrypt("scheme marker", &kp.public).expect("encrypt");
        assert_eq!(ENVELOPE_VERSION, 0x02);
        assert_eq!(env.to_bytes()[0], 0x02, "wire version byte must be 0x02");
    }

    // Exercises the config-flag / env-var precedence of `encryption_enabled`.
    //
    // NOTE(review): this mutates the process environment. The test harness
    // runs tests on parallel threads, so this can race with any other test in
    // the crate that reads the environment. Consider serialising env-touching
    // tests behind a shared lock. The restore at the end is also skipped if an
    // assertion panics first.
    #[test]
    fn encryption_enabled_config_flag_wins() {
        // Remember the caller's value so it can be restored afterwards.
        let prev = std::env::var("AI_MEMORY_ENCRYPT_AT_REST").ok();
        // SAFETY: the env mutators are unsafe because another thread may read
        // the environment concurrently. No other test in this module reads
        // AI_MEMORY_ENCRYPT_AT_REST, but see the NOTE above about the rest of
        // the crate.
        unsafe { std::env::remove_var("AI_MEMORY_ENCRYPT_AT_REST") };
        assert!(encryption_enabled(Some(true)));
        assert!(!encryption_enabled(Some(false)));
        assert!(!encryption_enabled(None));
        // SAFETY: same invariant as the `remove_var` above.
        unsafe { std::env::set_var("AI_MEMORY_ENCRYPT_AT_REST", "1") };
        assert!(encryption_enabled(None));
        assert!(encryption_enabled(Some(true)));
        if let Some(v) = prev {
            // SAFETY: same invariant as above; restores the original value.
            unsafe { std::env::set_var("AI_MEMORY_ENCRYPT_AT_REST", v) };
        } else {
            // SAFETY: same invariant as above; restores "unset".
            unsafe { std::env::remove_var("AI_MEMORY_ENCRYPT_AT_REST") };
        }
    }
}