use std::collections::HashMap;
use std::sync::Arc;
use aes_gcm::{
Aes256Gcm, KeyInit, Nonce,
aead::{Aead, Payload},
};
use axess_identity::{TenantId, UserId};
use base64::Engine;
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use crate::delegated::stored::credential::{DelegatedCredentialStore, StoredDelegation};
use axess_factors::ZeroizedString;
use axess_rng::SecureRng;
// Envelope wire-format constants. Every value below is baked into data that
// has already been persisted (the version prefix, the nonce layout, and the
// AAD bytes); changing any of them makes existing rows undecryptable.

/// Version tag written as the first `.`-separated segment of every envelope.
const ENVELOPE_VERSION: &str = "v1";
/// AES-GCM nonce length in bytes; a fresh random nonce of this size is
/// prepended to each ciphertext body.
const NONCE_LEN: usize = 12;
/// AAD field label that binds a ciphertext to the access-token column.
const FIELD_ACCESS: &[u8] = b"access_token";
/// AAD field label that binds a ciphertext to the refresh-token column.
const FIELD_REFRESH: &[u8] = b"refresh_token";
/// Separator byte (ASCII unit separator) placed between AAD components.
const AAD_SEP: u8 = 0x1f;
/// Raw 256-bit key material for AES-256-GCM.
///
/// The bytes are overwritten with zeros when the value is dropped, and the
/// `Debug` representation never reveals them.
pub struct EncryptionKey([u8; 32]);

impl EncryptionKey {
    /// Takes ownership of a copy of `bytes`. Because `[u8; 32]` is `Copy`, the
    /// caller's own array is not wiped by this type; that is the caller's job.
    pub fn from_bytes(bytes: [u8; 32]) -> Self {
        EncryptionKey(bytes)
    }

    /// Borrows the key bytes so they can be handed to the cipher.
    fn as_array(&self) -> &[u8; 32] {
        let EncryptionKey(raw) = self;
        raw
    }
}

impl Drop for EncryptionKey {
    /// Wipes the key material before the memory is released.
    fn drop(&mut self) {
        let raw = &mut self.0;
        zeroize::Zeroize::zeroize(raw);
    }
}

impl core::fmt::Debug for EncryptionKey {
    /// Prints a fixed placeholder so keys never end up in logs.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        write!(f, "EncryptionKey(***)")
    }
}
/// Source of AES-256-GCM keys for [`EncryptedDelegatedCredentialStore`].
///
/// Key ids are written verbatim into every envelope as
/// `v1.<key_id>.<body>`, and envelopes are parsed by splitting on `'.'` with
/// empty ids rejected. Implementations must therefore hand out ids that are
/// non-empty and contain no `'.'`.
pub trait KeyProvider: Send + Sync + 'static {
    /// Returns the key that new envelopes are sealed under, together with the
    /// id that gets recorded in each envelope.
    fn current(&self) -> Result<CurrentKey, KeyProviderError>;
    /// Looks up the key for an id read back from an envelope.
    ///
    /// `Ok(None)` means the id is unknown. Decryption then fails with an
    /// "unknown key id" error; no other keys are tried.
    fn resolve(&self, key_id: &str) -> Result<Option<Arc<EncryptionKey>>, KeyProviderError>;
}
/// The key that new envelopes are sealed under, paired with the id recorded
/// in each envelope so the matching key can be resolved again on read.
///
/// `Debug` is safe to derive: `EncryptionKey`'s `Debug` impl prints a
/// placeholder rather than key material.
#[derive(Clone, Debug)]
pub struct CurrentKey {
    /// Id embedded in the envelope; must be non-empty and free of `'.'`.
    pub key_id: Arc<str>,
    /// The AES-256-GCM key itself, shared with the provider.
    pub key: Arc<EncryptionKey>,
}
#[derive(Debug, thiserror::Error)]
pub enum KeyProviderError {
#[error("key provider failed: {0}")]
Failed(String),
#[error("key id contains '.', which is reserved by the envelope format: {0:?}")]
InvalidKeyId(String),
}
/// In-process [`KeyProvider`] with one current (encrypt + decrypt) key and
/// any number of historical, decrypt-only keys used during rotation.
#[derive(Clone, Debug)]
pub struct MemoryKeyProvider {
    current_id: Arc<str>,
    current_key: Arc<EncryptionKey>,
    historical: HashMap<String, Arc<EncryptionKey>>,
}

impl MemoryKeyProvider {
    /// Creates a provider whose current key is `key`, identified by `key_id`.
    ///
    /// # Errors
    /// [`KeyProviderError::InvalidKeyId`] if `key_id` is empty or contains `'.'`.
    pub fn new(key_id: impl Into<String>, key: [u8; 32]) -> Result<Self, KeyProviderError> {
        let current_id: String = key_id.into();
        validate_key_id(&current_id)?;
        let current_key = Arc::new(EncryptionKey::from_bytes(key));
        Ok(Self {
            current_id: current_id.into(),
            current_key,
            historical: HashMap::new(),
        })
    }

    /// Registers an older key so rows sealed under `key_id` remain readable.
    ///
    /// # Errors
    /// [`KeyProviderError::InvalidKeyId`] if `key_id` is empty or contains `'.'`.
    pub fn with_historical(
        mut self,
        key_id: impl Into<String>,
        key: [u8; 32],
    ) -> Result<Self, KeyProviderError> {
        let historical_id: String = key_id.into();
        validate_key_id(&historical_id)?;
        let historical_key = Arc::new(EncryptionKey::from_bytes(key));
        self.historical.insert(historical_id, historical_key);
        Ok(self)
    }
}
impl KeyProvider for MemoryKeyProvider {
    /// Hands out shared handles to the current id and key; never fails.
    fn current(&self) -> Result<CurrentKey, KeyProviderError> {
        let key_id = Arc::clone(&self.current_id);
        let key = Arc::clone(&self.current_key);
        Ok(CurrentKey { key_id, key })
    }

    /// Checks the current id first, then falls back to the historical keys.
    fn resolve(&self, key_id: &str) -> Result<Option<Arc<EncryptionKey>>, KeyProviderError> {
        let found = if &*self.current_id == key_id {
            Some(Arc::clone(&self.current_key))
        } else {
            self.historical.get(key_id).map(Arc::clone)
        };
        Ok(found)
    }
}
/// Accepts only key ids that can be embedded in an envelope: non-empty, and
/// free of `'.'` (the envelope's field separator).
fn validate_key_id(id: &str) -> Result<(), KeyProviderError> {
    if !id.is_empty() && !id.contains('.') {
        Ok(())
    } else {
        Err(KeyProviderError::InvalidKeyId(id.to_owned()))
    }
}
/// A [`DelegatedCredentialStore`] decorator that seals the token fields with
/// AES-256-GCM before handing rows to `inner`, and opens them again on load.
///
/// Only `access_token` and `refresh_token` are encrypted. Provider, expiry,
/// scopes and token type pass through to the inner store unchanged.
pub struct EncryptedDelegatedCredentialStore<S, K> {
    inner: S,
    keys: K,
}

impl<S, K> EncryptedDelegatedCredentialStore<S, K>
where
    S: DelegatedCredentialStore,
    K: KeyProvider,
{
    /// Wraps `inner`. `keys` supplies the write key and resolves read keys.
    pub fn new(inner: S, keys: K) -> Self {
        EncryptedDelegatedCredentialStore { keys, inner }
    }

    /// Direct access to the wrapped store. Rows read through it contain the
    /// raw envelopes, not plaintext tokens.
    pub fn inner(&self) -> &S {
        let Self { inner, .. } = self;
        inner
    }
}
impl<S, K> DelegatedCredentialStore for EncryptedDelegatedCredentialStore<S, K>
where
    S: DelegatedCredentialStore,
    K: KeyProvider,
{
    /// Loads the row from the inner store and decrypts both token fields.
    ///
    /// Returns `Ok(None)` when the inner store has no row. Any envelope that
    /// fails to parse or authenticate is reported as `Err`. That covers an
    /// unknown key id, tampered bytes, and a ciphertext moved to a different
    /// tenant/user/provider/field, because all of those are bound into the AAD.
    async fn load(
        &self,
        tenant: &TenantId,
        user: &UserId,
        provider: &str,
    ) -> Result<Option<StoredDelegation>, String> {
        let Some(cred) = self.inner.load(tenant, user, provider).await? else {
            return Ok(None);
        };
        let aad_access = build_aad(provider, tenant, user, FIELD_ACCESS);
        // Wrap right away, so that if decrypting the refresh token below fails
        // and we return early, the access plaintext is already owned by
        // `ZeroizedString`. That type is presumably wiped on drop (per its
        // name; not verified here).
        let access_token = ZeroizedString::from(
            decrypt_envelope(&self.keys, &cred.access_token, &aad_access)
                .map_err(|e| format!("decrypt access_token: {e}"))?,
        );
        let refresh_token = match cred.refresh_token.as_deref() {
            Some(rt) => {
                let aad_refresh = build_aad(provider, tenant, user, FIELD_REFRESH);
                let plain = decrypt_envelope(&self.keys, rt, &aad_refresh)
                    .map_err(|e| format!("decrypt refresh_token: {e}"))?;
                Some(ZeroizedString::from(plain))
            }
            None => None,
        };
        Ok(Some(StoredDelegation {
            provider: cred.provider,
            access_token,
            refresh_token,
            expires_at: cred.expires_at,
            scopes: cred.scopes,
            token_type: cred.token_type,
        }))
    }

    /// Seals both token fields under the provider's current key and stores
    /// the resulting envelopes through the inner store.
    async fn save(
        &self,
        tenant: &TenantId,
        user: &UserId,
        credential: StoredDelegation,
    ) -> Result<(), String> {
        let StoredDelegation {
            provider,
            access_token,
            refresh_token,
            expires_at,
            scopes,
            token_type,
        } = credential;
        // Take one key snapshot so both fields of the row share the same key id.
        let current = self.keys.current().map_err(|e| e.to_string())?;
        let aad_access = build_aad(&provider, tenant, user, FIELD_ACCESS);
        let access_env = encrypt_envelope(&current, &access_token, &aad_access)
            .map_err(|e| format!("encrypt access_token: {e}"))?;
        let refresh_env = match refresh_token.as_deref() {
            Some(rt) => {
                let aad_refresh = build_aad(&provider, tenant, user, FIELD_REFRESH);
                let env = encrypt_envelope(&current, rt, &aad_refresh)
                    .map_err(|e| format!("encrypt refresh_token: {e}"))?;
                Some(ZeroizedString::from(env))
            }
            None => None,
        };
        let wrapped = StoredDelegation {
            provider,
            access_token: ZeroizedString::from(access_env),
            refresh_token: refresh_env,
            expires_at,
            scopes,
            token_type,
        };
        self.inner.save(tenant, user, wrapped).await
    }

    /// Revocation needs no key material; it is forwarded to the inner store.
    async fn revoke(&self, tenant: &TenantId, user: &UserId, provider: &str) -> Result<(), String> {
        self.inner.revoke(tenant, user, provider).await
    }
}
/// Builds the additional authenticated data that binds a ciphertext to its
/// row and column. The layout is:
///
/// `provider 0x1f tenant 0x1f user 0x1f field`
///
/// These bytes are never stored; they must be rebuilt identically on read.
/// Any change to this layout makes existing envelopes undecryptable.
fn build_aad(provider: &str, tenant: &TenantId, user: &UserId, field: &[u8]) -> Vec<u8> {
    let components: [&[u8]; 4] = [provider.as_bytes(), tenant.as_bytes(), user.as_bytes(), field];
    components.join(&AAD_SEP)
}
/// Ways sealing or opening a single token envelope can fail.
///
/// The `Display` text is folded into the `String` errors returned by the
/// store, prefixed with the field name (e.g. `decrypt access_token: ...`).
#[derive(Debug, thiserror::Error)]
enum EnvelopeError {
    // Parsing the stored string.
    /// The envelope string, or its decoded body, is structurally invalid.
    #[error("malformed envelope: {0}")]
    Malformed(&'static str),
    /// The version prefix is not one this code can read.
    #[error("unknown envelope version {0:?}; store written by a newer axess?")]
    UnknownVersion(String),
    /// The body segment is not valid unpadded URL-safe base64.
    #[error("base64 decode failed")]
    Base64,

    // Key lookup.
    /// The key provider has no key registered under this id.
    #[error("unknown key id {0:?}")]
    UnknownKeyId(String),
    /// The key provider itself returned an error.
    #[error("key provider error: {0}")]
    KeyProvider(#[from] KeyProviderError),

    // Cryptographic operations.
    /// Authentication or decryption failed.
    #[error("decryption failed (wrong key, corrupted ciphertext, or AAD mismatch)")]
    Decrypt,
    /// The cipher refused to encrypt.
    #[error("encryption failed")]
    Encrypt,
}
fn encrypt_envelope(
current: &CurrentKey,
plaintext: &str,
aad: &[u8],
) -> Result<String, EnvelopeError> {
let cipher =
Aes256Gcm::new_from_slice(current.key.as_array()).map_err(|_| EnvelopeError::Encrypt)?;
let mut nonce_bytes = [0u8; NONCE_LEN];
axess_rng::SystemRng.fill_bytes(&mut nonce_bytes);
let nonce = Nonce::from_slice(&nonce_bytes);
let ct = cipher
.encrypt(
nonce,
Payload {
msg: plaintext.as_bytes(),
aad,
},
)
.map_err(|_| EnvelopeError::Encrypt)?;
let mut body = Vec::with_capacity(NONCE_LEN + ct.len());
body.extend_from_slice(&nonce_bytes);
body.extend_from_slice(&ct);
let b64 = URL_SAFE_NO_PAD.encode(&body);
Ok(format!("{ENVELOPE_VERSION}.{}.{}", &*current.key_id, b64))
}
fn decrypt_envelope<K: KeyProvider>(
keys: &K,
envelope: &str,
aad: &[u8],
) -> Result<String, EnvelopeError> {
let mut parts = envelope.splitn(3, '.');
let version = parts.next().ok_or(EnvelopeError::Malformed("no version"))?;
let key_id = parts.next().ok_or(EnvelopeError::Malformed("no key id"))?;
let body_b64 = parts.next().ok_or(EnvelopeError::Malformed("no body"))?;
if version != ENVELOPE_VERSION {
return Err(EnvelopeError::UnknownVersion(version.to_string()));
}
if key_id.is_empty() {
return Err(EnvelopeError::Malformed("empty key id"));
}
let body = URL_SAFE_NO_PAD
.decode(body_b64)
.map_err(|_| EnvelopeError::Base64)?;
if body.len() <= NONCE_LEN {
return Err(EnvelopeError::Malformed("body shorter than nonce"));
}
let (nonce_bytes, ciphertext) = body.split_at(NONCE_LEN);
let key = keys
.resolve(key_id)?
.ok_or_else(|| EnvelopeError::UnknownKeyId(key_id.to_string()))?;
let cipher = Aes256Gcm::new_from_slice(key.as_array()).map_err(|_| EnvelopeError::Decrypt)?;
let nonce = Nonce::from_slice(nonce_bytes);
let plaintext = cipher
.decrypt(
nonce,
Payload {
msg: ciphertext,
aad,
},
)
.map_err(|_| EnvelopeError::Decrypt)?;
String::from_utf8(plaintext).map_err(|_| EnvelopeError::Malformed("plaintext not utf-8"))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::delegated::stored::credential::MemoryDelegatedCredentialStore;
    use chrono::{TimeZone, Utc};

    fn sample_tenant() -> TenantId {
        TenantId::from_bytes([1u8; 16])
    }

    fn sample_user() -> UserId {
        UserId::from_bytes([2u8; 16])
    }

    fn sample_credential() -> StoredDelegation {
        StoredDelegation {
            provider: "gmail".to_string(),
            access_token: ZeroizedString::from("at-plaintext"),
            refresh_token: Some(ZeroizedString::from("rt-plaintext")),
            expires_at: Some(Utc.with_ymd_and_hms(2030, 1, 1, 0, 0, 0).unwrap()),
            scopes: vec!["gmail.send".to_string()],
            token_type: "Bearer".to_string(),
        }
    }

    /// Encrypted store over a fresh in-memory store, with current key id "k1".
    fn store_with_key(
        key: [u8; 32],
    ) -> EncryptedDelegatedCredentialStore<MemoryDelegatedCredentialStore, MemoryKeyProvider> {
        let keys = MemoryKeyProvider::new("k1", key).expect("valid key id");
        EncryptedDelegatedCredentialStore::new(MemoryDelegatedCredentialStore::new(), keys)
    }

    #[tokio::test]
    async fn save_then_load_roundtrips_plaintext() {
        let store = store_with_key([7u8; 32]);
        let tenant = sample_tenant();
        let user = sample_user();
        store
            .save(&tenant, &user, sample_credential())
            .await
            .expect("save");
        let loaded = store
            .load(&tenant, &user, "gmail")
            .await
            .expect("load")
            .expect("present");
        assert_eq!(&*loaded.access_token, "at-plaintext");
        assert_eq!(loaded.refresh_token.as_deref(), Some("rt-plaintext"));
        assert_eq!(loaded.provider, "gmail");
        assert_eq!(loaded.scopes, vec!["gmail.send".to_string()]);
    }

    #[tokio::test]
    async fn inner_store_holds_ciphertext_not_plaintext() {
        let inner = MemoryDelegatedCredentialStore::new();
        let keys = MemoryKeyProvider::new("k1", [9u8; 32]).expect("valid key");
        let store = EncryptedDelegatedCredentialStore::new(inner, keys);
        let tenant = sample_tenant();
        let user = sample_user();
        store
            .save(&tenant, &user, sample_credential())
            .await
            .expect("save");
        let raw = store
            .inner()
            .load(&tenant, &user, "gmail")
            .await
            .expect("inner load")
            .expect("present");
        assert!(
            (*raw.access_token).starts_with("v1.k1."),
            "expected envelope prefix, got: {:?}",
            &*raw.access_token
        );
        assert_ne!(&*raw.access_token, "at-plaintext");
        let rt = raw.refresh_token.as_deref().expect("refresh present");
        assert!(rt.starts_with("v1.k1."));
        assert_ne!(rt, "rt-plaintext");
    }

    #[tokio::test]
    async fn ciphertext_differs_across_saves_with_same_plaintext() {
        let store = store_with_key([3u8; 32]);
        let tenant = sample_tenant();
        let user = sample_user();
        store
            .save(&tenant, &user, sample_credential())
            .await
            .expect("first save");
        let first = store
            .inner()
            .load(&tenant, &user, "gmail")
            .await
            .expect("load")
            .expect("present");
        store
            .save(&tenant, &user, sample_credential())
            .await
            .expect("second save");
        let second = store
            .inner()
            .load(&tenant, &user, "gmail")
            .await
            .expect("load")
            .expect("present");
        assert_ne!(
            &*first.access_token, &*second.access_token,
            "two saves with same plaintext must yield different ciphertexts (random nonce)"
        );
    }

    #[tokio::test]
    async fn row_swap_attack_fails_on_decrypt() {
        let inner = MemoryDelegatedCredentialStore::new();
        let keys = MemoryKeyProvider::new("k1", [11u8; 32]).expect("valid key");
        let store = EncryptedDelegatedCredentialStore::new(inner, keys);
        let tenant_a = TenantId::from_bytes([1u8; 16]);
        let tenant_b = TenantId::from_bytes([2u8; 16]);
        let user = sample_user();
        store
            .save(&tenant_a, &user, sample_credential())
            .await
            .expect("save A");
        let stolen_row = store
            .inner()
            .load(&tenant_a, &user, "gmail")
            .await
            .expect("load A")
            .expect("present");
        store
            .inner()
            .save(&tenant_b, &user, stolen_row)
            .await
            .expect("inner save B");
        let result = store.load(&tenant_b, &user, "gmail").await;
        assert!(
            result.is_err(),
            "row-swap from tenant_a to tenant_b should not decrypt: {result:?}"
        );
    }

    #[tokio::test]
    async fn field_swap_attack_fails_on_decrypt() {
        let inner = MemoryDelegatedCredentialStore::new();
        let keys = MemoryKeyProvider::new("k1", [13u8; 32]).expect("valid key");
        let store = EncryptedDelegatedCredentialStore::new(inner, keys);
        let tenant = sample_tenant();
        let user = sample_user();
        store
            .save(&tenant, &user, sample_credential())
            .await
            .expect("save");
        let row = store
            .inner()
            .load(&tenant, &user, "gmail")
            .await
            .expect("inner load")
            .expect("present");
        let swapped = StoredDelegation {
            access_token: row.refresh_token.clone().expect("rt"),
            refresh_token: Some(row.access_token.clone()),
            ..row
        };
        store
            .inner()
            .save(&tenant, &user, swapped)
            .await
            .expect("inner save");
        let result = store.load(&tenant, &user, "gmail").await;
        assert!(result.is_err(), "field-swap should not decrypt: {result:?}");
    }

    #[tokio::test]
    async fn decrypt_fails_when_key_id_unknown() {
        let tenant = sample_tenant();
        let user = sample_user();
        let inner = MemoryDelegatedCredentialStore::new();
        let writer_keys = MemoryKeyProvider::new("k1", [21u8; 32]).expect("valid");
        let writer = EncryptedDelegatedCredentialStore::new(inner, writer_keys);
        writer
            .save(&tenant, &user, sample_credential())
            .await
            .expect("save");
        let raw = writer
            .inner()
            .load(&tenant, &user, "gmail")
            .await
            .expect("load")
            .expect("present");
        let fresh_inner = MemoryDelegatedCredentialStore::new();
        fresh_inner.save(&tenant, &user, raw).await.expect("plant");
        let reader_keys = MemoryKeyProvider::new("k2", [99u8; 32]).expect("valid");
        let reader = EncryptedDelegatedCredentialStore::new(fresh_inner, reader_keys);
        let result = reader.load(&tenant, &user, "gmail").await;
        assert!(result.is_err(), "unknown key id must surface as error");
    }

    #[tokio::test]
    async fn historical_key_decrypts_pre_rotation_rows() {
        let tenant = sample_tenant();
        let user = sample_user();
        let old_key = [4u8; 32];
        let new_key = [5u8; 32];
        let inner = MemoryDelegatedCredentialStore::new();
        let old_keys = MemoryKeyProvider::new("k1", old_key).expect("valid");
        let writer = EncryptedDelegatedCredentialStore::new(inner, old_keys);
        writer
            .save(&tenant, &user, sample_credential())
            .await
            .expect("save");
        let raw = writer
            .inner()
            .load(&tenant, &user, "gmail")
            .await
            .expect("load")
            .expect("present");
        let fresh_inner = MemoryDelegatedCredentialStore::new();
        fresh_inner.save(&tenant, &user, raw).await.expect("plant");
        let rotated_keys = MemoryKeyProvider::new("k2", new_key)
            .expect("valid")
            .with_historical("k1", old_key)
            .expect("valid");
        let store = EncryptedDelegatedCredentialStore::new(fresh_inner, rotated_keys);
        let loaded = store
            .load(&tenant, &user, "gmail")
            .await
            .expect("decrypt with historical k1")
            .expect("present");
        assert_eq!(&*loaded.access_token, "at-plaintext");
        store
            .save(&tenant, &user, loaded)
            .await
            .expect("save under k2");
        let rewritten = store
            .inner()
            .load(&tenant, &user, "gmail")
            .await
            .expect("load")
            .expect("present");
        assert!(
            (*rewritten.access_token).starts_with("v1.k2."),
            "after re-save, row should be under new current key, got: {:?}",
            &*rewritten.access_token
        );
    }

    #[tokio::test]
    async fn load_missing_credential_returns_none() {
        let store = store_with_key([1u8; 32]);
        let tenant = sample_tenant();
        let user = sample_user();
        let result = store.load(&tenant, &user, "gmail").await.expect("load");
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn revoke_removes_credential() {
        let store = store_with_key([6u8; 32]);
        let tenant = sample_tenant();
        let user = sample_user();
        store
            .save(&tenant, &user, sample_credential())
            .await
            .expect("save");
        store.revoke(&tenant, &user, "gmail").await.expect("revoke");
        let result = store.load(&tenant, &user, "gmail").await.expect("load");
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn missing_refresh_token_roundtrips_as_none() {
        let store = store_with_key([8u8; 32]);
        let tenant = sample_tenant();
        let user = sample_user();
        let mut cred = sample_credential();
        cred.refresh_token = None;
        store.save(&tenant, &user, cred).await.expect("save");
        let loaded = store
            .load(&tenant, &user, "gmail")
            .await
            .expect("load")
            .expect("present");
        assert!(loaded.refresh_token.is_none());
        assert_eq!(&*loaded.access_token, "at-plaintext");
    }

    #[test]
    fn memory_key_provider_rejects_dot_in_key_id() {
        let err = MemoryKeyProvider::new("has.dot", [0u8; 32]).unwrap_err();
        assert!(matches!(err, KeyProviderError::InvalidKeyId(_)));
    }

    #[test]
    fn memory_key_provider_rejects_empty_key_id() {
        let err = MemoryKeyProvider::new("", [0u8; 32]).unwrap_err();
        assert!(matches!(err, KeyProviderError::InvalidKeyId(_)));
    }

    // Purely synchronous: no runtime needed.
    #[test]
    fn unknown_envelope_version_rejects_cleanly() {
        let keys = MemoryKeyProvider::new("k1", [0u8; 32]).expect("valid");
        let fake_envelope = format!("v99.k1.{}", URL_SAFE_NO_PAD.encode([0u8; 32]));
        let aad = build_aad("gmail", &sample_tenant(), &sample_user(), FIELD_ACCESS);
        let err = decrypt_envelope(&keys, &fake_envelope, &aad).unwrap_err();
        assert!(
            matches!(err, EnvelopeError::UnknownVersion(ref v) if v == "v99"),
            "expected UnknownVersion, got {err:?}"
        );
    }

    // Purely synchronous: no runtime needed.
    #[test]
    fn short_envelope_body_rejects_as_malformed() {
        let keys = MemoryKeyProvider::new("k1", [0u8; 32]).expect("valid");
        let short = format!("v1.k1.{}", URL_SAFE_NO_PAD.encode([0u8; NONCE_LEN]));
        let aad = build_aad("gmail", &sample_tenant(), &sample_user(), FIELD_ACCESS);
        let err = decrypt_envelope(&keys, &short, &aad).unwrap_err();
        assert!(matches!(err, EnvelopeError::Malformed(_)));
    }

    /// Flipping a single ciphertext/tag bit must fail GCM authentication.
    #[test]
    fn tampered_ciphertext_fails_authentication() {
        let keys = MemoryKeyProvider::new("k1", [17u8; 32]).expect("valid");
        let current = keys.current().expect("current key");
        let aad = build_aad("gmail", &sample_tenant(), &sample_user(), FIELD_ACCESS);
        let envelope = encrypt_envelope(&current, "secret", &aad).expect("encrypt");
        let (prefix, body_b64) = envelope.rsplit_once('.').expect("envelope has a body");
        let mut body = URL_SAFE_NO_PAD.decode(body_b64).expect("valid base64");
        *body.last_mut().expect("non-empty body") ^= 0x01;
        let tampered = format!("{prefix}.{}", URL_SAFE_NO_PAD.encode(&body));
        let err = decrypt_envelope(&keys, &tampered, &aad).unwrap_err();
        assert!(matches!(err, EnvelopeError::Decrypt), "expected Decrypt, got {err:?}");
    }

    /// A ciphertext bound to one provider must not open under another's AAD.
    #[test]
    fn provider_mismatch_in_aad_fails_on_decrypt() {
        let keys = MemoryKeyProvider::new("k1", [19u8; 32]).expect("valid");
        let current = keys.current().expect("current key");
        let aad_gmail = build_aad("gmail", &sample_tenant(), &sample_user(), FIELD_ACCESS);
        let envelope = encrypt_envelope(&current, "secret", &aad_gmail).expect("encrypt");
        let aad_other = build_aad("outlook", &sample_tenant(), &sample_user(), FIELD_ACCESS);
        let err = decrypt_envelope(&keys, &envelope, &aad_other).unwrap_err();
        assert!(matches!(err, EnvelopeError::Decrypt), "expected Decrypt, got {err:?}");
    }

    /// The smallest legitimate body (nonce + tag, no ciphertext) must decrypt.
    #[test]
    fn empty_plaintext_roundtrips() {
        let keys = MemoryKeyProvider::new("k1", [23u8; 32]).expect("valid");
        let current = keys.current().expect("current key");
        let aad = build_aad("gmail", &sample_tenant(), &sample_user(), FIELD_REFRESH);
        let envelope = encrypt_envelope(&current, "", &aad).expect("encrypt");
        let plain = decrypt_envelope(&keys, &envelope, &aad).expect("decrypt");
        assert_eq!(plain, "");
    }

    #[test]
    fn structurally_invalid_envelopes_are_rejected() {
        let keys = MemoryKeyProvider::new("k1", [0u8; 32]).expect("valid");
        let aad = build_aad("gmail", &sample_tenant(), &sample_user(), FIELD_ACCESS);

        let missing_body = decrypt_envelope(&keys, "v1.k1", &aad).unwrap_err();
        assert!(
            matches!(missing_body, EnvelopeError::Malformed(_)),
            "expected Malformed, got {missing_body:?}"
        );

        let empty_key_id = decrypt_envelope(&keys, "v1..AAAA", &aad).unwrap_err();
        assert!(
            matches!(empty_key_id, EnvelopeError::Malformed(_)),
            "expected Malformed, got {empty_key_id:?}"
        );

        let bad_base64 = decrypt_envelope(&keys, "v1.k1.***", &aad).unwrap_err();
        assert!(
            matches!(bad_base64, EnvelopeError::Base64),
            "expected Base64, got {bad_base64:?}"
        );
    }
}