use std::path::Path;
use std::sync::Arc;
use aes_gcm::aead::{Aead, KeyInit, Payload};
use aes_gcm::{Aes256Gcm, Key, Nonce};
use bytes::Bytes;
use rand::RngCore;
use thiserror::Error;
/// Magic prefix at offset 0 of every body produced by `encrypt`.
pub const SSE_MAGIC: &[u8; 4] = b"S4E1";
/// Size of the fixed header that precedes the ciphertext:
/// magic (4) + algorithm tag (1) + reserved (3) + nonce + GCM tag.
pub const SSE_HEADER_BYTES: usize = 4 + 1 + 3 + NONCE_LEN + TAG_LEN;
/// Algorithm tag stored at byte offset 4 for AES-256-GCM.
pub const ALGO_AES_256_GCM: u8 = 1;
/// AES-GCM nonce length in bytes (96 bits).
const NONCE_LEN: usize = 12;
/// AES-GCM authentication tag length in bytes.
const TAG_LEN: usize = 16;
/// AES-256 key length in bytes.
const KEY_LEN: usize = 32;
/// Errors produced while loading an SSE key or decrypting an SSE body.
#[derive(Debug, Error)]
pub enum SseError {
    /// Reading the key file from disk failed.
    #[error("SSE key file {path:?}: {source}")]
    KeyFileIo {
        /// Path whose read failed.
        path: std::path::PathBuf,
        /// Underlying I/O error.
        source: std::io::Error,
    },
    /// Key material was not in any accepted form.
    ///
    /// `got` is the length of the input exactly as supplied, before any
    /// whitespace trimming or hex/base64 decoding.
    #[error(
        "SSE key file must be exactly 32 raw bytes (or 64-char hex / 44-char base64); got {got} bytes of input"
    )]
    BadKeyLength { got: usize },
    /// The body is shorter than the fixed SSE header.
    #[error("SSE-encrypted body too short ({got} bytes; need at least {SSE_HEADER_BYTES})")]
    TooShort { got: usize },
    /// The first four bytes are not the SSE magic.
    #[error("SSE bad magic: expected S4E1, got {got:?}")]
    BadMagic { got: [u8; 4] },
    /// The algorithm byte names an algorithm this build cannot decrypt.
    #[error("SSE unsupported algo tag: {tag} (this build only knows AES-256-GCM = 1)")]
    UnsupportedAlgo { tag: u8 },
    /// AES-GCM authentication failed: wrong key, or the body was modified.
    #[error("SSE decryption / authentication failed (key mismatch or ciphertext tampered with)")]
    DecryptFailed,
}
/// A 256-bit AES key used for server-side encryption.
///
/// The key bytes sit behind an `Arc`, so `clone` bumps a reference count
/// instead of copying key material. The `Debug` impl in this module prints
/// a placeholder rather than the key.
#[derive(Clone)]
pub struct SseKey(Arc<[u8; KEY_LEN]>);
impl SseKey {
    /// Loads a key from the file at `path`.
    ///
    /// The file contents are interpreted exactly as by [`SseKey::from_bytes`].
    ///
    /// # Errors
    ///
    /// [`SseError::KeyFileIo`] if the file cannot be read; otherwise whatever
    /// [`SseKey::from_bytes`] returns.
    pub fn from_path(path: &Path) -> Result<Self, SseError> {
        match std::fs::read(path) {
            Ok(contents) => Self::from_bytes(&contents),
            Err(source) => Err(SseError::KeyFileIo {
                path: path.to_path_buf(),
                source,
            }),
        }
    }

    /// Builds a key from raw bytes or a textual encoding of them.
    ///
    /// Accepted forms, tried in this order:
    /// 1. exactly 32 bytes, used verbatim;
    /// 2. 64 hexadecimal digits (upper or lower case);
    /// 3. text decoded with the `base64` crate's `STANDARD` engine to exactly 32 bytes.
    ///
    /// Forms 2 and 3 must be valid UTF-8 and may carry surrounding whitespace
    /// (for example a trailing newline in a key file).
    ///
    /// # Errors
    ///
    /// [`SseError::BadKeyLength`], carrying the input length, if no form matches.
    pub fn from_bytes(bytes: &[u8]) -> Result<Self, SseError> {
        let bad_len = || SseError::BadKeyLength { got: bytes.len() };

        if let Ok(raw) = <[u8; KEY_LEN]>::try_from(bytes) {
            return Ok(Self(Arc::new(raw)));
        }

        // Non-UTF-8 input becomes "", which then fails the base64 check below.
        let text = match std::str::from_utf8(bytes) {
            Ok(t) => t.trim(),
            Err(_) => "",
        };

        if text.len() == KEY_LEN * 2 && text.bytes().all(|b| b.is_ascii_hexdigit()) {
            let mut key = [0u8; KEY_LEN];
            // Every byte is an ASCII hex digit, so each 2-byte chunk is valid UTF-8 and parses.
            for (slot, pair) in key.iter_mut().zip(text.as_bytes().chunks_exact(2)) {
                let pair = std::str::from_utf8(pair).map_err(|_| bad_len())?;
                *slot = u8::from_str_radix(pair, 16).map_err(|_| bad_len())?;
            }
            return Ok(Self(Arc::new(key)));
        }

        base64::Engine::decode(&base64::engine::general_purpose::STANDARD, text.as_bytes())
            .ok()
            .and_then(|decoded| <[u8; KEY_LEN]>::try_from(decoded.as_slice()).ok())
            .map(|key| Self(Arc::new(key)))
            .ok_or_else(bad_len)
    }

    /// Borrows the key bytes in the form the `aes-gcm` cipher constructor takes.
    fn as_aes_key(&self) -> &Key<Aes256Gcm> {
        Key::<Aes256Gcm>::from_slice(self.0.as_slice())
    }
}
/// Redacting `Debug`: reports only the fixed key length, never the key bytes.
impl std::fmt::Debug for SseKey {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const REDACTED: &str = "<redacted>";
        let mut view = f.debug_struct("SseKey");
        view.field("len", &KEY_LEN);
        view.field("key", &REDACTED);
        view.finish()
    }
}
/// Encrypts `plaintext` under `key` with AES-256-GCM and a fresh random nonce.
///
/// Output layout (byte offsets):
///
/// | range    | contents                                |
/// |----------|-----------------------------------------|
/// | `0..4`   | [`SSE_MAGIC`]                            |
/// | `4`      | [`ALGO_AES_256_GCM`]                     |
/// | `5..8`   | reserved, always zero                    |
/// | `8..20`  | nonce read from the OS RNG               |
/// | `20..36` | GCM authentication tag                   |
/// | `36..`   | ciphertext (same length as `plaintext`)  |
///
/// Bytes `0..8` are also passed to AES-GCM as associated data.
///
/// # Panics
///
/// Panics if the AES-GCM implementation returns an error. Presumably this can
/// only happen for plaintexts beyond AES-GCM's per-message length limit.
pub fn encrypt(key: &SseKey, plaintext: &[u8]) -> Bytes {
    let mut nonce = [0u8; NONCE_LEN];
    rand::rngs::OsRng.fill_bytes(&mut nonce);

    // Write the 8-byte prefix first; those exact bytes double as the AAD.
    let mut out = Vec::with_capacity(SSE_HEADER_BYTES + plaintext.len());
    out.extend_from_slice(SSE_MAGIC);
    out.push(ALGO_AES_256_GCM);
    out.extend_from_slice(&[0u8; 3]);

    let sealed = Aes256Gcm::new(key.as_aes_key())
        .encrypt(
            Nonce::from_slice(&nonce),
            Payload {
                msg: plaintext,
                aad: &out[..],
            },
        )
        .expect("aes-gcm encrypt cannot fail with a 32-byte key");

    // The crate returns ciphertext || tag; the wire format stores the tag first.
    debug_assert!(sealed.len() >= TAG_LEN);
    let (body, tag) = sealed.split_at(sealed.len() - TAG_LEN);

    out.extend_from_slice(&nonce);
    out.extend_from_slice(tag);
    out.extend_from_slice(body);
    Bytes::from(out)
}
/// Decrypts a body produced by [`encrypt`] and returns the plaintext.
///
/// The 8-byte header prefix (magic, algorithm tag, reserved bytes) is handed
/// to AES-GCM as associated data exactly as it appears in `body`. Changing any
/// of those bytes, including the reserved ones, therefore fails authentication.
///
/// # Errors
///
/// * [`SseError::TooShort`] if `body` is shorter than [`SSE_HEADER_BYTES`].
/// * [`SseError::BadMagic`] if the first four bytes are not [`SSE_MAGIC`].
/// * [`SseError::UnsupportedAlgo`] if byte 4 is not [`ALGO_AES_256_GCM`].
/// * [`SseError::DecryptFailed`] if the key is wrong, or if the header, nonce,
///   tag or ciphertext was modified.
pub fn decrypt(key: &SseKey, body: &[u8]) -> Result<Bytes, SseError> {
    /// Bytes authenticated as AAD: magic (4) + algorithm tag (1) + reserved (3).
    const AAD_LEN: usize = 8;

    if body.len() < SSE_HEADER_BYTES {
        return Err(SseError::TooShort { got: body.len() });
    }
    let mut magic = [0u8; 4];
    magic.copy_from_slice(&body[..4]);
    if &magic != SSE_MAGIC {
        return Err(SseError::BadMagic { got: magic });
    }
    let algo = body[4];
    if algo != ALGO_AES_256_GCM {
        return Err(SseError::UnsupportedAlgo { tag: algo });
    }

    // Authenticate the header bytes actually present, not a reconstructed copy.
    // Otherwise the reserved bytes 5..8 could be altered without detection.
    let aad = &body[..AAD_LEN];
    let nonce = Nonce::from_slice(&body[AAD_LEN..AAD_LEN + NONCE_LEN]);
    let tag = &body[AAD_LEN + NONCE_LEN..SSE_HEADER_BYTES];
    let ct = &body[SSE_HEADER_BYTES..];

    // `Aead::decrypt` expects ciphertext || tag, while the wire format stores the tag first.
    let mut ct_with_tag = Vec::with_capacity(ct.len() + TAG_LEN);
    ct_with_tag.extend_from_slice(ct);
    ct_with_tag.extend_from_slice(tag);

    let cipher = Aes256Gcm::new(key.as_aes_key());
    let plain = cipher
        .decrypt(
            nonce,
            Payload {
                msg: &ct_with_tag,
                aad,
            },
        )
        .map_err(|_| SseError::DecryptFailed)?;
    Ok(Bytes::from(plain))
}
/// Cheap format sniff: `true` if `body` is at least header-sized and starts with [`SSE_MAGIC`].
///
/// This checks no cryptography; a `true` result does not guarantee that
/// [`decrypt`] will succeed.
pub fn looks_encrypted(body: &[u8]) -> bool {
    if body.len() < SSE_HEADER_BYTES {
        return false;
    }
    body.starts_with(SSE_MAGIC)
}
/// Reference-counted handle to an [`SseKey`], for sharing one key across owners.
pub type SharedSseKey = Arc<SseKey>;
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic test key: 32 bytes of 0x07.
    fn key32() -> SseKey {
        SseKey::from_bytes(&[7u8; 32]).unwrap()
    }

    #[test]
    fn roundtrip_basic() {
        let k = key32();
        let pt = b"the quick brown fox jumps over the lazy dog";
        let ct = encrypt(&k, pt);
        assert!(looks_encrypted(&ct));
        assert_eq!(&ct[..4], SSE_MAGIC);
        assert_eq!(ct[4], ALGO_AES_256_GCM);
        assert_eq!(&ct[5..8], &[0u8; 3]);
        assert_eq!(ct.len(), SSE_HEADER_BYTES + pt.len());
        let pt2 = decrypt(&k, &ct).unwrap();
        assert_eq!(pt2.as_ref(), pt);
    }

    #[test]
    fn roundtrip_empty_plaintext() {
        let k = key32();
        let ct = encrypt(&k, b"");
        assert_eq!(ct.len(), SSE_HEADER_BYTES);
        assert!(looks_encrypted(&ct));
        assert!(decrypt(&k, &ct).unwrap().is_empty());
    }

    #[test]
    fn wrong_key_fails() {
        let k1 = SseKey::from_bytes(&[1u8; 32]).unwrap();
        let k2 = SseKey::from_bytes(&[2u8; 32]).unwrap();
        let ct = encrypt(&k1, b"secret");
        let err = decrypt(&k2, &ct).unwrap_err();
        assert!(matches!(err, SseError::DecryptFailed));
    }

    #[test]
    fn tampered_ciphertext_fails() {
        let k = key32();
        let mut ct = encrypt(&k, b"secret message").to_vec();
        let last = ct.len() - 1;
        ct[last] ^= 0x01;
        let err = decrypt(&k, &ct).unwrap_err();
        assert!(matches!(err, SseError::DecryptFailed));
    }

    #[test]
    fn tampered_nonce_or_tag_fails() {
        let k = key32();
        let ct = encrypt(&k, b"secret message").to_vec();
        // First nonce byte, then first tag byte.
        for idx in [8, 8 + NONCE_LEN] {
            let mut bad = ct.clone();
            bad[idx] ^= 0x01;
            let err = decrypt(&k, &bad).unwrap_err();
            assert!(matches!(err, SseError::DecryptFailed), "offset {idx}");
        }
    }

    #[test]
    fn tampered_algo_byte_fails() {
        let k = key32();
        let mut ct = encrypt(&k, b"secret").to_vec();
        ct[4] = 99;
        let err = decrypt(&k, &ct).unwrap_err();
        assert!(matches!(err, SseError::UnsupportedAlgo { tag: 99 }));
    }

    #[test]
    fn bad_magic_rejected() {
        let k = key32();
        let mut ct = encrypt(&k, b"secret").to_vec();
        ct[0] = b'X';
        let err = decrypt(&k, &ct).unwrap_err();
        assert!(matches!(err, SseError::BadMagic { got } if &got == b"X4E1"));
    }

    #[test]
    fn rejects_short_body() {
        let k = key32();
        let err = decrypt(&k, b"short").unwrap_err();
        assert!(matches!(err, SseError::TooShort { got: 5 }));
    }

    #[test]
    fn looks_encrypted_passthrough_returns_false() {
        assert!(!looks_encrypted(b"S4F2\x01\x00\x00\x00........"));
        assert!(!looks_encrypted(b""));
    }

    #[test]
    fn key_from_hex_string() {
        // 62 hex digits: neither 64-char hex nor valid padded base64.
        let err =
            SseKey::from_bytes(b"0102030405060708090a0b0c0d0e0f10111213141516171819202122232425")
                .unwrap_err();
        assert!(matches!(err, SseError::BadKeyLength { got: 62 }));
        let good = b"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef";
        let _ = SseKey::from_bytes(good).expect("64-char hex should parse");
        let upper_nl = b"0123456789ABCDEF0123456789ABCDEF0123456789ABCDEF0123456789ABCDEF\n";
        let _ = SseKey::from_bytes(upper_nl)
            .expect("uppercase hex with a trailing newline should parse");
    }

    #[test]
    fn hex_and_raw_forms_agree() {
        let hex = "07".repeat(32);
        let from_hex = SseKey::from_bytes(hex.as_bytes()).unwrap();
        let ct = encrypt(&from_hex, b"payload");
        assert_eq!(decrypt(&key32(), &ct).unwrap().as_ref(), b"payload");
    }

    #[test]
    fn key_from_base64_string() {
        let b64 = base64::Engine::encode(&base64::engine::general_purpose::STANDARD, [7u8; 32]);
        assert_eq!(b64.len(), 44);
        let from_b64 = SseKey::from_bytes(format!("{b64}\n").as_bytes())
            .expect("base64 key with trailing newline should parse");
        let ct = encrypt(&from_b64, b"payload");
        assert_eq!(decrypt(&key32(), &ct).unwrap().as_ref(), b"payload");
    }

    #[test]
    fn rejects_garbage_key() {
        let err = SseKey::from_bytes(b"too short").unwrap_err();
        assert!(matches!(err, SseError::BadKeyLength { got: 9 }));
    }

    #[test]
    fn debug_does_not_leak_key() {
        let rendered = format!("{:?}", key32());
        assert_eq!(rendered, "SseKey { len: 32, key: \"<redacted>\" }");
    }

    #[test]
    fn encrypt_uses_random_nonce() {
        let k = key32();
        let pt = b"deterministic input";
        let a = encrypt(&k, pt);
        let b = encrypt(&k, pt);
        assert_ne!(a, b, "nonce must be random per-call");
    }
}