use crate::{CliError, CliResult};
use chacha20poly1305::aead::{Aead, KeyInit, Payload};
use chacha20poly1305::{XChaCha20Poly1305, XNonce};
use rand::rngs::OsRng;
use rand::RngCore;
use std::path::{Path, PathBuf};
/// Identifier of the AEAD used by [`encrypt`] / [`decrypt`] (XChaCha20-Poly1305).
/// NOTE(review): presumably persisted next to sealed records so readers can
/// tell which cipher produced them — confirm against callers.
pub const CUSTODY_CIPHER_ID: &str = "xchacha20poly1305";
/// Version number of the custody key scheme.
/// NOTE(review): presumably stored alongside sealed records to support future
/// key rotation — confirm against callers.
pub const CUSTODY_KEY_VERSION: i64 = 1;
/// Key size in bytes (256-bit XChaCha20-Poly1305 key); a key file must be exactly this long.
const KEY_LEN: usize = 32;
/// Nonce size in bytes (192-bit extended XChaCha nonce).
const NONCE_LEN: usize = 24;
/// Build the associated data that binds a custody ciphertext to its audit record.
///
/// Layout: the ASCII domain tag `zynk-custody-v1`, then for `audit_id` and then
/// `payload_hash`, a little-endian `u32` byte length followed by the UTF-8 bytes.
/// The length prefixes stop two different `(audit_id, payload_hash)` pairs from
/// concatenating to the same bytes.
///
/// Every ciphertext authenticates this exact layout, so changing it (including the
/// `u32` width) would make previously sealed data undecryptable. Note that the
/// `as u32` cast truncates the prefix for fields of 4 GiB or more.
fn aad(audit_id: &str, payload_hash: &str) -> Vec<u8> {
    const DOMAIN_TAG: &[u8] = b"zynk-custody-v1";
    let mut buf = Vec::with_capacity(DOMAIN_TAG.len() + 8 + audit_id.len() + payload_hash.len());
    buf.extend_from_slice(DOMAIN_TAG);
    for field in [audit_id, payload_hash] {
        buf.extend_from_slice(&(field.len() as u32).to_le_bytes());
        buf.extend_from_slice(field.as_bytes());
    }
    buf
}
/// Seal `plaintext` under `key` with XChaCha20-Poly1305.
///
/// A fresh random 24-byte nonce is drawn from the OS RNG on every call. The
/// associated data binds the result to `audit_id` and `payload_hash`, so
/// [`decrypt`] must be given the same two values to open it.
///
/// Returns `(ciphertext, nonce)`. The ciphertext is the AEAD output, which carries
/// the authentication tag. The nonce must be stored with it for later decryption.
///
/// # Errors
/// Returns `CliError::failure` if the AEAD encryption reports an error.
pub fn encrypt(
    key: &[u8; KEY_LEN],
    audit_id: &str,
    payload_hash: &str,
    plaintext: &[u8],
) -> CliResult<(Vec<u8>, Vec<u8>)> {
    let mut nonce = [0u8; NONCE_LEN];
    OsRng.fill_bytes(&mut nonce);

    let associated = aad(audit_id, payload_hash);
    let payload = Payload {
        msg: plaintext,
        aad: associated.as_slice(),
    };

    let sealed = XChaCha20Poly1305::new(key.into())
        .encrypt(XNonce::from_slice(&nonce), payload)
        .map_err(|_| CliError::failure("custody encryption failed"))?;

    Ok((sealed, Vec::from(nonce)))
}
/// Open a ciphertext produced by [`encrypt`].
///
/// `audit_id` and `payload_hash` must match the values used at encryption time,
/// because they form the authenticated associated data.
///
/// # Errors
/// * `nonce` is not exactly 24 bytes long. This is checked up front because
///   `XNonce::from_slice` panics on a length mismatch.
/// * Authentication fails (tampered ciphertext, wrong key, or mismatched AAD).
pub fn decrypt(
    key: &[u8; KEY_LEN],
    audit_id: &str,
    payload_hash: &str,
    nonce: &[u8],
    ciphertext: &[u8],
) -> CliResult<Vec<u8>> {
    if nonce.len() != NONCE_LEN {
        return Err(CliError::failure("custody nonce has the wrong length"));
    }

    let associated = aad(audit_id, payload_hash);
    let payload = Payload {
        msg: ciphertext,
        aad: associated.as_slice(),
    };

    XChaCha20Poly1305::new(key.into())
        .decrypt(XNonce::from_slice(nonce), payload)
        .map_err(|_| {
            CliError::failure("custody decryption failed (tamper, wrong key, or AAD mismatch)")
        })
}
/// Decide which file holds the custody key.
///
/// Precedence:
/// 1. the explicit `flag` path, if given;
/// 2. the `ZYNK_CUSTODY_KEY_FILE` environment variable, if set and non-empty;
/// 3. `custody.key` in the directory containing `db_path`, or
///    `.zynk/custody.key` when `db_path` has no parent (e.g. `/` or an empty path).
///
/// The environment variable is read with `var_os` so that paths which are not
/// valid UTF-8 are honoured. `std::env::var` would reject them, which would
/// silently fall through to the default location. Through `load_or_create_key`,
/// that could mean generating and using a brand-new key instead of the one the
/// operator pointed at.
pub fn resolve_key_path(flag: Option<&Path>, db_path: &Path) -> PathBuf {
    if let Some(p) = flag {
        return p.to_path_buf();
    }
    if let Some(env) = std::env::var_os("ZYNK_CUSTODY_KEY_FILE") {
        if !env.is_empty() {
            return PathBuf::from(env);
        }
    }
    db_path
        .parent()
        .map(|d| d.join("custody.key"))
        .unwrap_or_else(|| PathBuf::from(".zynk/custody.key"))
}
/// Load the custody key at `path`, generating and persisting a new one if the
/// file does not exist yet.
///
/// * Existing file: delegated to `load_existing_key` (permission and size checks).
/// * Missing file: missing parent directories are created, 32 random bytes are
///   drawn from the OS RNG, and they are written with `write_key_0600`.
///
/// The existence check and the write are separate steps, not one atomic
/// operation. How an already-present file is treated at write time is up to
/// `write_key_0600`.
///
/// # Errors
/// Propagates directory-creation, key-write, and key-load failures.
pub fn load_or_create_key(path: &Path) -> CliResult<[u8; KEY_LEN]> {
    if !path.exists() {
        if let Some(parent) = path.parent() {
            std::fs::create_dir_all(parent).map_err(|e| {
                CliError::failure(format!("failed to create custody key dir: {e}"))
            })?;
        }
        let mut fresh = [0u8; KEY_LEN];
        OsRng.fill_bytes(&mut fresh);
        write_key_0600(path, &fresh)?;
        return Ok(fresh);
    }
    load_existing_key(path)
}
pub fn load_existing_key(path: &Path) -> CliResult<[u8; KEY_LEN]> {
if !path.exists() {
return Err(CliError::failure(format!(
"custody key file not found: {}",
path.display()
)));
}
check_perms_0600(path)?;
let bytes = std::fs::read(path)
.map_err(|e| CliError::failure(format!("failed to read custody key: {e}")))?;
if bytes.len() != KEY_LEN {
return Err(CliError::usage(format!(
"custody key file must be exactly {KEY_LEN} bytes (got {})",
bytes.len()
)));
}
let mut key = [0u8; KEY_LEN];
key.copy_from_slice(&bytes);
Ok(key)
}
/// Reject a key file that any principal other than its owner can access.
///
/// Any group or other bit (read, write, or execute) fails the check. The owner's
/// own bits are not inspected. `std::fs::metadata` follows symlinks, so the
/// target's mode is what gets checked.
///
/// # Errors
/// * `CliError::failure` if the file cannot be stat'ed.
/// * `CliError::usage` naming the offending mode when group/other bits are set.
#[cfg(unix)]
fn check_perms_0600(path: &Path) -> CliResult<()> {
    use std::os::unix::fs::PermissionsExt;

    let meta = std::fs::metadata(path)
        .map_err(|e| CliError::failure(format!("failed to stat custody key: {e}")))?;
    let mode = meta.permissions().mode();

    let group_or_other_bits = mode & 0o077;
    if group_or_other_bits == 0 {
        return Ok(());
    }
    Err(CliError::usage(format!(
        "custody key {} is group/world-accessible (mode {:o}); chmod 600 it",
        path.display(),
        mode & 0o777
    )))
}
/// Permission check for non-Unix targets: always succeeds.
///
/// No owner-only check is performed on these platforms.
/// NOTE(review): access control for the key file then depends entirely on the
/// platform's own defaults/ACLs — confirm this is acceptable for deployments.
#[cfg(not(unix))]
fn check_perms_0600(_path: &Path) -> CliResult<()> {
    Ok(())
}
/// Create `path` exclusively with mode 0600 and durably write `key` into it.
///
/// * `create_new` makes the open fail if the file already exists, so an existing
///   key is never overwritten.
/// * The requested mode is further reduced by the process umask, so the result
///   never grants more than owner read/write.
/// * The data is `sync_all`'d before returning. Callers may seal data under this
///   key immediately, and losing the key to a crash would make that data
///   unrecoverable.
/// * If writing or syncing fails, the freshly created file is removed. Otherwise
///   a partial file would make every later load fail the 32-byte length check.
///
/// # Errors
/// `CliError::failure` if the file cannot be created, written, or synced.
#[cfg(unix)]
fn write_key_0600(path: &Path, key: &[u8; KEY_LEN]) -> CliResult<()> {
    use std::io::Write;
    use std::os::unix::fs::OpenOptionsExt;
    let mut f = std::fs::OpenOptions::new()
        .create_new(true)
        .write(true)
        .mode(0o600)
        .open(path)
        .map_err(|e| CliError::failure(format!("failed to create custody key: {e}")))?;
    let written = f.write_all(key).and_then(|()| f.sync_all());
    if let Err(e) = written {
        drop(f);
        // Best-effort cleanup of the file we just created; the write error is what matters.
        let _ = std::fs::remove_file(path);
        return Err(CliError::failure(format!(
            "failed to write custody key: {e}"
        )));
    }
    Ok(())
}
/// Create `path` exclusively and durably write `key` into it (non-Unix variant).
///
/// * `create_new` makes the open fail if the file already exists. The previous
///   `std::fs::write` truncated and overwrote an existing file, so a key created
///   concurrently (after the caller's existence check) could be clobbered. That
///   would destroy the only key able to open data sealed under it.
/// * The data is `sync_all`'d before returning.
/// * If writing or syncing fails, the freshly created file is removed so no
///   partial key is left behind.
///
/// No owner-only permission mode is applied on these platforms.
///
/// # Errors
/// `CliError::failure` if the file cannot be created, written, or synced.
#[cfg(not(unix))]
fn write_key_0600(path: &Path, key: &[u8; KEY_LEN]) -> CliResult<()> {
    use std::io::Write;
    let mut f = std::fs::OpenOptions::new()
        .create_new(true)
        .write(true)
        .open(path)
        .map_err(|e| CliError::failure(format!("failed to create custody key: {e}")))?;
    let written = f.write_all(key).and_then(|()| f.sync_all());
    if let Err(e) = written {
        // The handle must be closed before the file can be deleted on Windows.
        drop(f);
        let _ = std::fs::remove_file(path);
        return Err(CliError::failure(format!(
            "failed to write custody key: {e}"
        )));
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixed test key; real keys come from the OS RNG.
    const TEST_KEY: [u8; KEY_LEN] = [7u8; KEY_LEN];

    #[test]
    fn round_trips() {
        let (ct, nonce) = encrypt(&TEST_KEY, "aud-1", "sha256:ab", b"secret").unwrap();
        let pt = decrypt(&TEST_KEY, "aud-1", "sha256:ab", &nonce, &ct).unwrap();
        assert_eq!(pt, b"secret");
    }

    #[test]
    fn aad_move_rejected() {
        let (ct, nonce) = encrypt(&TEST_KEY, "aud-1", "sha256:ab", b"secret").unwrap();
        assert!(decrypt(&TEST_KEY, "aud-2", "sha256:ab", &nonce, &ct).is_err());
        assert!(decrypt(&TEST_KEY, "aud-1", "sha256:cd", &nonce, &ct).is_err());
    }

    #[test]
    fn tamper_rejected() {
        let (mut ct, nonce) = encrypt(&TEST_KEY, "aud-1", "sha256:ab", b"secret").unwrap();
        ct[0] ^= 0xff;
        assert!(decrypt(&TEST_KEY, "aud-1", "sha256:ab", &nonce, &ct).is_err());
    }

    #[test]
    fn wrong_key_rejected() {
        let key_b = [8u8; KEY_LEN];
        let (ct, nonce) = encrypt(&TEST_KEY, "aud-1", "sha256:ab", b"secret").unwrap();
        let result = decrypt(&key_b, "aud-1", "sha256:ab", &nonce, &ct);
        assert!(result.is_err(), "a wrong key must fail to decrypt");
    }

    #[test]
    fn garbage_ciphertext_no_panic() {
        let nonce = [9u8; NONCE_LEN];
        assert!(decrypt(&TEST_KEY, "aud-1", "sha256:ab", &nonce, &[0u8; 4]).is_err());
        assert!(decrypt(&TEST_KEY, "aud-1", "sha256:ab", &nonce, &[]).is_err());
        assert!(decrypt(&TEST_KEY, "aud-1", "sha256:ab", &nonce, &[0xABu8; 40]).is_err());
    }

    /// Covers the explicit length guard in `decrypt`; without it `XNonce::from_slice` would panic.
    #[test]
    fn wrong_nonce_length_rejected() {
        let (ct, nonce) = encrypt(&TEST_KEY, "aud-1", "sha256:ab", b"secret").unwrap();
        assert!(decrypt(&TEST_KEY, "aud-1", "sha256:ab", &nonce[..12], &ct).is_err());
        assert!(decrypt(&TEST_KEY, "aud-1", "sha256:ab", &[], &ct).is_err());
        let mut long = nonce.clone();
        long.push(0);
        assert!(decrypt(&TEST_KEY, "aud-1", "sha256:ab", &long, &ct).is_err());
    }

    #[test]
    fn encrypt_uses_fresh_nonce() {
        let (ct1, n1) = encrypt(&TEST_KEY, "aud-1", "sha256:ab", b"secret").unwrap();
        let (ct2, n2) = encrypt(&TEST_KEY, "aud-1", "sha256:ab", b"secret").unwrap();
        assert_eq!(n1.len(), NONCE_LEN);
        assert_ne!(n1, n2, "each encryption must draw a new nonce");
        assert_ne!(ct1, ct2);
    }

    #[test]
    fn created_key_reloads_identically() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("nested").join("custody.key");
        let created = load_or_create_key(&path).unwrap();
        assert_eq!(load_or_create_key(&path).unwrap(), created);
        assert_eq!(load_existing_key(&path).unwrap(), created);
    }

    #[test]
    fn missing_key_file_rejected() {
        let dir = tempfile::tempdir().unwrap();
        assert!(load_existing_key(&dir.path().join("absent.key")).is_err());
    }

    #[test]
    fn explicit_key_flag_wins() {
        let flag = Path::new("/explicit/custody.key");
        let resolved = resolve_key_path(Some(flag), Path::new("/data/zynk.db"));
        assert_eq!(resolved, PathBuf::from("/explicit/custody.key"));
    }

    #[cfg(unix)]
    #[test]
    fn created_key_is_owner_only() {
        use std::os::unix::fs::PermissionsExt;
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("custody.key");
        let _ = load_or_create_key(&path).unwrap();
        let mode = std::fs::metadata(&path).unwrap().permissions().mode();
        assert_eq!(mode & 0o077, 0, "new key must not be group/world-accessible");
    }

    #[cfg(unix)]
    #[test]
    fn key_file_rejects_world_readable() {
        use std::os::unix::fs::PermissionsExt;
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("custody.key");
        let _ = load_or_create_key(&path).unwrap();
        std::fs::set_permissions(&path, std::fs::Permissions::from_mode(0o644)).unwrap();
        assert!(load_existing_key(&path).is_err());
    }

    #[cfg(unix)]
    #[test]
    fn key_file_rejects_wrong_size() {
        use std::io::Write;
        use std::os::unix::fs::OpenOptionsExt;
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("custody.key");
        let mut f = std::fs::OpenOptions::new()
            .create_new(true)
            .write(true)
            .mode(0o600)
            .open(&path)
            .unwrap();
        f.write_all(&[0u8; 10]).unwrap();
        drop(f);
        assert!(load_existing_key(&path).is_err());
    }
}