use chacha20poly1305::{
aead::{Aead, KeyInit, OsRng, Payload},
AeadCore, XChaCha20Poly1305, XNonce,
};
use sha2::{Digest, Sha256};
use uuid::Uuid;
use crate::error::RetrievalError;
/// Length in bytes of the XChaCha20-Poly1305 nonce stored at the front of every
/// encrypted audit blob (`nonce || ciphertext`).
const NONCE_LEN: usize = 24;
/// Domain-separation prefix for the AEAD associated data; `aad` appends the
/// org id's 16 bytes to it so a blob only authenticates for its own org.
const AAD_MAGIC: &[u8] = b"tt-retrieval:audit_log:v1";
/// Persists encrypted copies of original prompts into the `retrieval_audit_log` table.
///
/// Each prompt is sealed with XChaCha20-Poly1305 under a per-org key derived from
/// `master_key`; only the resulting ciphertext blob is written to the database.
pub struct RetrievalAuditLog {
    /// Postgres connection pool used for the audit `INSERT`s.
    pool: sqlx::PgPool,
    /// 32-byte master secret from which per-org encryption keys are derived.
    /// The `Debug` impl redacts it.
    /// NOTE(review): no zeroize-on-drop is visible here — consider wiping it on drop.
    master_key: [u8; 32],
}
impl std::fmt::Debug for RetrievalAuditLog {
    /// Renders the handle for diagnostics without exposing the pool internals or
    /// any byte of the master key.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const POOL_PLACEHOLDER: &str = "PgPool { .. }";
        const KEY_PLACEHOLDER: &str = "[REDACTED]";

        let mut out = f.debug_struct("RetrievalAuditLog");
        out.field("pool", &POOL_PLACEHOLDER);
        out.field("master_key", &KEY_PLACEHOLDER);
        out.finish()
    }
}
impl RetrievalAuditLog {
    /// Creates an audit log from an existing pool and a raw 32-byte master key.
    pub fn new(pool: sqlx::PgPool, master_key: [u8; 32]) -> Self {
        Self { master_key, pool }
    }

    /// Creates an audit log whose master key is read from the `TT_MASTER_KEY`
    /// environment variable as hex (surrounding whitespace is ignored).
    ///
    /// Returns `None` when the variable is unset or not valid Unicode, is not
    /// valid hex, or does not decode to exactly 32 bytes.
    pub fn from_env(pool: sqlx::PgPool) -> Option<Self> {
        let encoded = std::env::var("TT_MASTER_KEY").ok()?;
        let decoded = hex::decode(encoded.trim()).ok()?;
        let key = <[u8; 32]>::try_from(decoded).ok()?;
        Some(Self::new(pool, key))
    }

    /// Encrypts `original_prompt` for `org_id` and inserts one audit row with a
    /// fresh random row id.
    ///
    /// `substitutions` is bound as an `i32`, saturating at `i32::MAX` if it does
    /// not fit.
    ///
    /// # Errors
    /// Propagates any error from `encrypt_prompt`, and maps a failed insert to
    /// `RetrievalError::Store`.
    pub async fn record(
        &self,
        org_id: Uuid,
        substitutions: u32,
        tokens_saved: i64,
        original_prompt: &str,
    ) -> Result<(), RetrievalError> {
        const INSERT_SQL: &str = r#"INSERT INTO retrieval_audit_log
(id, org_id, prompt_enc, substitutions, tokens_saved)
VALUES ($1, $2, $3, $4, $5)"#;

        let prompt_enc = encrypt_prompt(&self.master_key, org_id, original_prompt)?;
        let substitution_count = i32::try_from(substitutions).unwrap_or(i32::MAX);
        let row_id = Uuid::new_v4();

        sqlx::query(INSERT_SQL)
            .bind(row_id)
            .bind(org_id)
            .bind(&prompt_enc)
            .bind(substitution_count)
            .bind(tokens_saved)
            .execute(&self.pool)
            .await
            .map(|_| ())
            .map_err(|e| RetrievalError::Store(format!("audit insert: {e}")))
    }

    /// Decrypts a `prompt_enc` blob that was sealed for `org_id` with this
    /// instance's master key.
    ///
    /// # Errors
    /// Returns whatever `decrypt_prompt` reports (malformed blob, failed
    /// authentication, or non-UTF-8 plaintext).
    pub fn decrypt(&self, org_id: Uuid, blob: &[u8]) -> Result<String, RetrievalError> {
        let master = &self.master_key;
        decrypt_prompt(master, org_id, blob)
    }
}
/// Derives the per-org audit encryption key as
/// `SHA-256(master || "|tt-retrieval:audit-key:v1|" || org_id bytes)`.
///
/// All three inputs have fixed lengths, so the concatenation is unambiguous.
fn derive_audit_key(master: &[u8; 32], org_id: Uuid) -> [u8; 32] {
    let parts: [&[u8]; 3] = [master, b"|tt-retrieval:audit-key:v1|", org_id.as_bytes()];
    let mut hasher = Sha256::new();
    for part in parts.iter().copied() {
        hasher.update(part);
    }
    hasher.finalize().into()
}
/// Builds the AEAD associated data for `org_id`: `AAD_MAGIC` followed by the
/// org id's raw bytes. This binds each ciphertext to both the schema version
/// and the owning org.
fn aad(org_id: Uuid) -> Vec<u8> {
    let pieces: [&[u8]; 2] = [AAD_MAGIC, org_id.as_bytes()];
    pieces.concat()
}
/// Seals `plain` for `org_id` and returns `nonce || ciphertext`.
///
/// A fresh random 24-byte nonce is drawn from `OsRng` on every call. The key
/// comes from `derive_audit_key`, and the associated data from `aad`.
///
/// # Errors
/// Returns `RetrievalError::Store("audit encrypt failed")` if the AEAD rejects
/// the input.
fn encrypt_prompt(master: &[u8; 32], org_id: Uuid, plain: &str) -> Result<Vec<u8>, RetrievalError> {
    let derived = derive_audit_key(master, org_id);
    let sealer = XChaCha20Poly1305::new((&derived).into());
    let nonce = XChaCha20Poly1305::generate_nonce(&mut OsRng);
    let associated = aad(org_id);
    let payload = Payload {
        msg: plain.as_bytes(),
        aad: associated.as_slice(),
    };

    let sealed = match sealer.encrypt(&nonce, payload) {
        Ok(bytes) => bytes,
        Err(_) => return Err(RetrievalError::Store("audit encrypt failed".into())),
    };

    // Layout: nonce first so decryption can split it off at a fixed offset.
    Ok([nonce.as_slice(), sealed.as_slice()].concat())
}
/// Opens a blob produced by `encrypt_prompt` and returns the original prompt text.
///
/// The expected layout is `nonce (NONCE_LEN bytes) || ciphertext || 16-byte tag`.
/// The key comes from `derive_audit_key` and the associated data from `aad`, so a
/// blob only opens for the org and master key it was sealed with.
///
/// # Errors
/// - `RetrievalError::Malformed` if the blob is too short to contain a nonce plus
///   an authentication tag, or if the decrypted bytes are not valid UTF-8.
/// - `RetrievalError::Store` if authentication fails. Causes include a wrong
///   master key, a wrong org, or tampered bytes.
fn decrypt_prompt(master: &[u8; 32], org_id: Uuid, blob: &[u8]) -> Result<String, RetrievalError> {
    // XChaCha20-Poly1305 appends a 16-byte Poly1305 tag to every ciphertext.
    // Anything shorter than nonce + tag cannot be a valid blob, so classify it as
    // malformed up front instead of letting it surface as a decrypt failure.
    const TAG_LEN: usize = 16;
    if blob.len() < NONCE_LEN + TAG_LEN {
        return Err(RetrievalError::Malformed("audit blob too short".into()));
    }
    let (nonce_bytes, ciphertext) = blob.split_at(NONCE_LEN);
    // Cannot panic: `split_at(NONCE_LEN)` yields exactly NONCE_LEN bytes here.
    let nonce = XNonce::from_slice(nonce_bytes);
    let key = derive_audit_key(master, org_id);
    let cipher = XChaCha20Poly1305::new((&key).into());
    let ad = aad(org_id);
    let plain = cipher
        .decrypt(
            nonce,
            Payload {
                msg: ciphertext,
                aad: &ad,
            },
        )
        .map_err(|_| RetrievalError::Store("audit decrypt failed".into()))?;
    String::from_utf8(plain).map_err(|_| RetrievalError::Malformed("audit utf8".into()))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A sealed prompt decrypts back to the exact same text.
    #[test]
    fn encrypt_then_decrypt_round_trips() {
        let master = [5u8; 32];
        let org = Uuid::from_u128(7);
        let prompt = r#"{"messages":[{"role":"user","content":"secret prompt"}]}"#;
        let blob = encrypt_prompt(&master, org, prompt).unwrap();
        assert_eq!(decrypt_prompt(&master, org, &blob).unwrap(), prompt);
    }

    /// The empty prompt is a valid boundary input: nonce + tag only.
    #[test]
    fn empty_prompt_round_trips() {
        let master = [8u8; 32];
        let org = Uuid::from_u128(8);
        let blob = encrypt_prompt(&master, org, "").unwrap();
        assert_eq!(decrypt_prompt(&master, org, &blob).unwrap(), "");
    }

    /// The stored blob must not leak the plaintext verbatim.
    #[test]
    fn ciphertext_does_not_contain_plaintext() {
        let master = [1u8; 32];
        let org = Uuid::from_u128(1);
        let blob = encrypt_prompt(&master, org, "api_key=sk-leakme").unwrap();
        assert!(!String::from_utf8_lossy(&blob).contains("sk-leakme"));
    }

    /// Blobs are bound to both the org (via key derivation and AAD) and the master key.
    #[test]
    fn wrong_org_or_master_fails_to_decrypt() {
        let master = [2u8; 32];
        let org = Uuid::from_u128(2);
        let blob = encrypt_prompt(&master, org, "hello").unwrap();
        assert!(decrypt_prompt(&master, Uuid::from_u128(3), &blob).is_err());
        assert!(decrypt_prompt(&[9u8; 32], org, &blob).is_err());
    }

    /// Flipping any bit in the tag or the nonce must be detected by the AEAD.
    #[test]
    fn tampered_blob_is_rejected() {
        let master = [4u8; 32];
        let org = Uuid::from_u128(4);
        let blob = encrypt_prompt(&master, org, "integrity matters").unwrap();

        let mut bad_tag = blob.clone();
        let last = bad_tag.len() - 1;
        bad_tag[last] ^= 0x01;
        assert!(decrypt_prompt(&master, org, &bad_tag).is_err());

        let mut bad_nonce = blob;
        bad_nonce[0] ^= 0x01;
        assert!(decrypt_prompt(&master, org, &bad_nonce).is_err());
    }

    /// Each encryption draws a fresh nonce, so identical inputs give distinct blobs.
    #[test]
    fn repeated_encryption_uses_fresh_nonces() {
        let master = [6u8; 32];
        let org = Uuid::from_u128(6);
        let a = encrypt_prompt(&master, org, "same").unwrap();
        let b = encrypt_prompt(&master, org, "same").unwrap();
        assert_ne!(&a[..NONCE_LEN], &b[..NONCE_LEN]);
        assert_ne!(a, b);
    }

    /// Blobs too short to hold a nonce are rejected rather than panicking.
    #[test]
    fn truncated_blob_is_rejected() {
        assert!(decrypt_prompt(&[0u8; 32], Uuid::nil(), &[0u8; 5]).is_err());
    }
}