use base64::Engine;
use chacha20poly1305::{
aead::{Aead, KeyInit, Payload},
Key, XChaCha20Poly1305, XNonce,
};
use rand::RngCore;
use sha2::{Digest, Sha256};
/// Length in bytes of the symmetric key handed to XChaCha20-Poly1305.
const KEY_LEN: usize = 32;
/// Length in bytes of the extended XChaCha20 nonce stored in every token.
const NONCE_LEN: usize = 24;
/// Length in bytes of the Poly1305 authentication tag carried after the ciphertext.
const TAG_LEN: usize = 16;
/// Length in bytes of the leading version marker of the wire format.
const VERSION_LEN: usize = 1;
/// Smallest structurally valid token: tag + nonce + version around an empty plaintext.
const MIN_TOKEN_LEN: usize = TAG_LEN + NONCE_LEN + VERSION_LEN;
/// Opaque failure returned when a token cannot be opened.
///
/// This is a unit struct with no payload: in this module a bad length, a version
/// mismatch, undecodable base64 and a failed AEAD check all produce this same value,
/// so callers cannot tell the causes apart.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SealError;

impl std::fmt::Display for SealError {
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Fixed, detail-free message; width/alignment flags are intentionally ignored.
        out.write_str("token verification failed")
    }
}

impl std::error::Error for SealError {}
/// Maps an operator-supplied key of any length onto exactly [`KEY_LEN`] bytes.
///
/// A key that is already `KEY_LEN` bytes long is copied through unchanged; every
/// other length (including zero) is replaced by its SHA-256 digest.
///
/// NOTE(review): an empty `key` is accepted and becomes SHA-256 of the empty input,
/// a publicly known value, so a missing key silently yields no secrecy. SHA-256 here
/// is unsalted and fast; if keys may be human-chosen passphrases, a password KDF
/// would be more appropriate — confirm how keys are provisioned.
pub fn normalize_key(key: &[u8]) -> [u8; KEY_LEN] {
    let mut normalized = [0u8; KEY_LEN];
    match key.len() {
        KEY_LEN => normalized.copy_from_slice(key),
        _ => normalized.copy_from_slice(&Sha256::digest(key)[..]),
    }
    normalized
}
/// Encrypts and authenticates `payload` into a self-contained token.
///
/// The returned bytes are laid out as `version || nonce || ciphertext`, where the
/// ciphertext produced by XChaCha20-Poly1305 includes its authentication tag.
///
/// * `payload` — plaintext to protect.
/// * `key` — operator key of any length, normalized with [`normalize_key`].
/// * `aad` — associated data that is authenticated but neither encrypted nor stored
///   in the token; the same bytes must be passed to [`open_bytes`].
/// * `version` — copied verbatim into the first byte of the token.
///
/// A new random nonce is drawn from `rand::thread_rng()` on every call, so sealing
/// the same input twice yields different tokens.
///
/// NOTE(review): `version` is written in the clear and is not part of the AEAD
/// associated data, so it is not covered by authentication.
///
/// # Panics
///
/// Panics if the AEAD `encrypt` call returns an error, which this code assumes
/// cannot happen for in-memory plaintexts.
pub fn seal_bytes(payload: &[u8], key: &[u8], aad: &[u8], version: u8) -> Vec<u8> {
    let cipher = XChaCha20Poly1305::new(Key::from_slice(&normalize_key(key)));

    let mut nonce = [0u8; NONCE_LEN];
    rand::thread_rng().fill_bytes(&mut nonce);

    let sealed = cipher
        .encrypt(XNonce::from_slice(&nonce), Payload { msg: payload, aad })
        .expect("XChaCha20-Poly1305 encrypt cannot fail for in-memory plaintext");

    // `concat` sizes the output once from the three parts.
    [&[version][..], &nonce[..], &sealed[..]].concat()
}
/// Verifies and decrypts a token produced by [`seal_bytes`].
///
/// `key` and `aad` must match what was used when sealing, and the token's first byte
/// must equal `version`.
///
/// # Errors
///
/// Returns [`SealError`] in any of these cases; all failures are indistinguishable:
/// * the token is shorter than `MIN_TOKEN_LEN`;
/// * the leading byte differs from `version`;
/// * AEAD authentication fails under the normalized key and `aad`.
///
/// NOTE(review): the leading version byte is not fed into the AEAD, so anyone who can
/// modify a token can rewrite that byte and the token will then open successfully
/// under the new `version`. Binding it (e.g. into the associated data) would change
/// the wire format.
pub fn open_bytes(
    token: &[u8],
    key: &[u8],
    aad: &[u8],
    version: u8,
) -> Result<Vec<u8>, SealError> {
    // Cheap structural checks that need no key material come first.
    let well_formed = token.len() >= MIN_TOKEN_LEN && token.first() == Some(&version);
    if !well_formed {
        return Err(SealError);
    }

    let (nonce, ciphertext) = token[VERSION_LEN..].split_at(NONCE_LEN);
    let cipher = XChaCha20Poly1305::new(Key::from_slice(&normalize_key(key)));
    let sealed_input = Payload {
        msg: ciphertext,
        aad,
    };
    cipher
        .decrypt(XNonce::from_slice(nonce), sealed_input)
        .map_err(|_| SealError)
}
/// Same as [`seal_bytes`], but returns the token as standard (padded) base64 text.
pub fn seal_base64(payload: &[u8], key: &[u8], aad: &[u8], version: u8) -> String {
    use base64::engine::general_purpose::STANDARD;

    let raw_token = seal_bytes(payload, key, aad, version);
    STANDARD.encode(raw_token)
}
/// Inverse of [`seal_base64`]: decodes standard base64 text, then calls [`open_bytes`].
///
/// # Errors
///
/// Returns [`SealError`] when `token` is not valid standard base64, or for any
/// reason [`open_bytes`] rejects the decoded bytes.
pub fn open_base64(
    token: &str,
    key: &[u8],
    aad: &[u8],
    version: u8,
) -> Result<Vec<u8>, SealError> {
    use base64::engine::general_purpose::STANDARD;

    match STANDARD.decode(token.as_bytes()) {
        Ok(raw) => open_bytes(&raw, key, aad, version),
        // Decoding errors are collapsed into the same opaque error as everything else.
        Err(_) => Err(SealError),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use sha2::{Digest, Sha256};

    #[test]
    fn roundtrip() {
        let key = [7u8; 32];
        let sealed = seal_bytes(b"hello world", &key, b"aad", 4);
        assert_eq!(sealed[0], 4);
        let opened = open_bytes(&sealed, &key, b"aad", 4).unwrap();
        assert_eq!(opened, b"hello world");
    }

    #[test]
    fn empty_payload_roundtrips_at_minimum_length() {
        let key = [7u8; 32];
        let sealed = seal_bytes(b"", &key, b"aad", 4);
        // version + nonce + tag, with no ciphertext body in between.
        assert_eq!(sealed.len(), MIN_TOKEN_LEN);
        assert_eq!(open_bytes(&sealed, &key, b"aad", 4).unwrap(), b"");
    }

    #[test]
    fn each_seal_uses_a_fresh_nonce() {
        let key = [7u8; 32];
        let first = seal_bytes(b"same", &key, b"aad", 4);
        let second = seal_bytes(b"same", &key, b"aad", 4);
        let nonce = VERSION_LEN..VERSION_LEN + NONCE_LEN;
        assert_ne!(&first[nonce.clone()], &second[nonce]);
        assert_ne!(first, second);
    }

    #[test]
    fn wrong_key_aad_version_all_fail_uniformly() {
        let key = [7u8; 32];
        let sealed = seal_bytes(b"payload", &key, b"aad", 4);
        assert_eq!(open_bytes(&sealed, &[9u8; 32], b"aad", 4), Err(SealError));
        assert_eq!(open_bytes(&sealed, &key, b"other", 4), Err(SealError));
        assert_eq!(open_bytes(&sealed, &key, b"aad", 5), Err(SealError));
        assert_eq!(open_bytes(b"short", &key, b"aad", 4), Err(SealError));
        let mut tampered = sealed.clone();
        *tampered.last_mut().unwrap() ^= 0x01;
        assert_eq!(open_bytes(&tampered, &key, b"aad", 4), Err(SealError));
    }

    #[test]
    fn malformed_tokens_are_rejected() {
        let key = [7u8; 32];
        let sealed = seal_bytes(b"payload", &key, b"aad", 4);
        // Empty input and one byte below the structural minimum.
        assert_eq!(open_bytes(b"", &key, b"aad", 4), Err(SealError));
        assert_eq!(
            open_bytes(&sealed[..MIN_TOKEN_LEN - 1], &key, b"aad", 4),
            Err(SealError)
        );
        // Long enough to parse, but truncated so the trailing tag no longer matches.
        assert_eq!(
            open_bytes(&sealed[..sealed.len() - 1], &key, b"aad", 4),
            Err(SealError)
        );
        // A flipped nonce bit must fail authentication.
        let mut bad_nonce = sealed.clone();
        bad_nonce[VERSION_LEN] ^= 0x01;
        assert_eq!(open_bytes(&bad_nonce, &key, b"aad", 4), Err(SealError));
    }

    #[test]
    fn normalize_key_passthrough_and_hash() {
        let exact = [3u8; 32];
        assert_eq!(normalize_key(&exact), exact);
        // Every other length, including empty, must become the SHA-256 digest of the input.
        for &len in &[0usize, 1, 31, 33, 64] {
            let key = vec![3u8; len];
            assert_eq!(&normalize_key(&key)[..], &Sha256::digest(&key)[..]);
        }
        assert_eq!(normalize_key(b"short"), normalize_key(b"short"));
        assert_ne!(normalize_key(b"short"), normalize_key(b"other"));
    }

    #[test]
    fn base64_helpers_roundtrip() {
        let key = b"operator-supplied-key-of-any-length";
        let tok = seal_base64(b"state", key, b"id", 4);
        assert_eq!(open_base64(&tok, key, b"id", 4).unwrap(), b"state");
        assert_eq!(open_base64("!!!notb64!!!", key, b"id", 4), Err(SealError));
        // Valid base64 that decodes to too few bytes is still rejected.
        assert_eq!(open_base64("", key, b"id", 4), Err(SealError));
    }
}