use std::time::{SystemTime, UNIX_EPOCH};
use aes_gcm::{
aead::{Aead, KeyInit, Payload},
Aes256Gcm, Nonce as AesNonce,
};
use chacha20poly1305::{ChaCha20Poly1305, Nonce as ChachaNonce};
use hkdf::Hkdf;
use rand::{rngs::OsRng, RngCore};
use sha2::{Digest, Sha256};
/// Two-byte prefix identifying a payload as an encrypted envelope.
const ENVELOPE_MARKER: [u8; 2] = [0xFF, 0x20];
/// Nonce length in bytes (96-bit nonces for both supported AEADs).
const NONCE_LEN: usize = 12;
/// Smallest well-formed envelope: marker + algorithm byte + nonce + 16-byte authentication tag.
const MIN_ENCRYPTED_LEN: usize = ENVELOPE_MARKER.len() + 1 + NONCE_LEN + 16;
/// HKDF `info` label for per-frame key derivation; changing it changes every derived key.
const HKDF_INFO: &[u8] = b"wireband_frame_key_v1";
/// AEAD algorithm selector; each discriminant is the algorithm byte written into an envelope.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AlgoId {
    /// AES-256-GCM, wire byte `0x01`.
    Aes256Gcm = 0x01,
    /// ChaCha20-Poly1305, wire byte `0x02`.
    ChaCha20Poly1305 = 0x02,
}
impl AlgoId {
    /// Every supported algorithm; `from_byte` searches this list.
    const ALL: [AlgoId; 2] = [AlgoId::Aes256Gcm, AlgoId::ChaCha20Poly1305];

    /// Stable lowercase name of the algorithm.
    pub fn name(&self) -> &'static str {
        match *self {
            AlgoId::ChaCha20Poly1305 => "chacha20poly1305",
            AlgoId::Aes256Gcm => "aes256gcm",
        }
    }

    /// Decodes an envelope's algorithm byte; unknown bytes yield `None`.
    fn from_byte(b: u8) -> Option<Self> {
        Self::ALL.iter().copied().find(|algo| *algo as u8 == b)
    }
}
/// A 32-byte AEAD key bound to the algorithm it is used with.
///
/// The `Debug` output shows only the algorithm and a short key fingerprint, never the key.
#[derive(Clone)]
pub struct EnvelopeCipher {
    /// Algorithm used for both sealing and opening.
    algo: AlgoId,
    /// Raw 256-bit key.
    key: [u8; 32],
    /// Lowercase hex key fingerprint, printed by `Debug` in place of the key.
    fingerprint: String,
}
impl std::fmt::Debug for EnvelopeCipher {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `key` is deliberately left out so secrets never reach logs.
        let mut out = f.debug_struct("EnvelopeCipher");
        out.field("algo", &self.algo);
        out.field("fingerprint", &self.fingerprint);
        out.finish()
    }
}
impl EnvelopeCipher {
    /// Wraps a raw 32-byte key for use with `algo`.
    ///
    /// The fingerprint is the lowercase hex of the first four bytes of SHA-256(key).
    pub fn new(key: [u8; 32], algo: AlgoId) -> Self {
        let digest = Sha256::digest(key);
        Self {
            algo,
            key,
            fingerprint: hex_encode_lower(&digest[..4]),
        }
    }

    /// Parses a 64-character hex key and wraps it for `algo`.
    ///
    /// # Errors
    /// [`CryptoError::BadHex`] when `hex` is not a valid 32-byte hex string.
    pub fn from_hex(hex: &str, algo: AlgoId) -> Result<Self, CryptoError> {
        parse_hex32(hex).map(|key| Self::new(key, algo))
    }

    /// Short key fingerprint (8 lowercase hex characters).
    pub fn fingerprint(&self) -> &str {
        self.fingerprint.as_str()
    }

    /// Algorithm this cipher seals and opens with.
    pub fn algo(&self) -> AlgoId {
        self.algo
    }

    /// Seals `payload` into an envelope laid out as
    /// `marker (2) || algo byte (1) || nonce (12) || AEAD output`.
    ///
    /// A fresh nonce is drawn from the OS RNG on every call. `aad` is authenticated but not
    /// stored; the receiver must supply the same bytes to [`Self::decrypt`].
    ///
    /// # Errors
    /// [`CryptoError::Init`] if the AEAD cannot be keyed, [`CryptoError::Encrypt`] if sealing fails.
    pub fn encrypt(&self, payload: &[u8], aad: &[u8]) -> Result<Vec<u8>, CryptoError> {
        let mut nonce = [0u8; NONCE_LEN];
        OsRng.fill_bytes(&mut nonce);
        let input = Payload { msg: payload, aad };

        let sealed = match self.algo {
            AlgoId::Aes256Gcm => Aes256Gcm::new_from_slice(&self.key)
                .map_err(|e| CryptoError::Init(e.to_string()))?
                .encrypt(AesNonce::from_slice(&nonce), input),
            AlgoId::ChaCha20Poly1305 => ChaCha20Poly1305::new_from_slice(&self.key)
                .map_err(|e| CryptoError::Init(e.to_string()))?
                .encrypt(ChachaNonce::from_slice(&nonce), input),
        }
        .map_err(|e| CryptoError::Encrypt(e.to_string()))?;

        let mut envelope =
            Vec::with_capacity(ENVELOPE_MARKER.len() + 1 + NONCE_LEN + sealed.len());
        envelope.extend_from_slice(&ENVELOPE_MARKER);
        envelope.push(self.algo as u8);
        envelope.extend_from_slice(&nonce);
        envelope.extend(sealed);
        Ok(envelope)
    }

    /// Opens an envelope produced by [`Self::encrypt`] with the same `aad`.
    ///
    /// # Errors
    /// * [`CryptoError::TooShort`] — shorter than `MIN_ENCRYPTED_LEN`;
    /// * [`CryptoError::BadMarker`] — does not start with the envelope marker;
    /// * [`CryptoError::AlgoMismatch`] — the algorithm byte is not this cipher's;
    /// * [`CryptoError::Init`] — the AEAD cannot be keyed;
    /// * [`CryptoError::AuthFailed`] — wrong key, wrong `aad`, or tampered bytes.
    pub fn decrypt(&self, data: &[u8], aad: &[u8]) -> Result<Vec<u8>, CryptoError> {
        if data.len() < MIN_ENCRYPTED_LEN {
            return Err(CryptoError::TooShort(data.len()));
        }

        let (marker, rest) = data.split_at(ENVELOPE_MARKER.len());
        if *marker != ENVELOPE_MARKER {
            return Err(CryptoError::BadMarker);
        }

        // Safe to index: the length check above guarantees the full header is present.
        let algo_byte = rest[0];
        if AlgoId::from_byte(algo_byte) != Some(self.algo) {
            return Err(CryptoError::AlgoMismatch {
                expected: self.algo as u8,
                got: algo_byte,
            });
        }

        let (nonce, sealed) = rest[1..].split_at(NONCE_LEN);
        let input = Payload { msg: sealed, aad };
        let opened = match self.algo {
            AlgoId::Aes256Gcm => Aes256Gcm::new_from_slice(&self.key)
                .map_err(|e| CryptoError::Init(e.to_string()))?
                .decrypt(AesNonce::from_slice(nonce), input),
            AlgoId::ChaCha20Poly1305 => ChaCha20Poly1305::new_from_slice(&self.key)
                .map_err(|e| CryptoError::Init(e.to_string()))?
                .decrypt(ChachaNonce::from_slice(nonce), input),
        };
        opened.map_err(|_| CryptoError::AuthFailed)
    }
}
/// Secret bijection over all 16-bit frame symbols, with precomputed forward and inverse tables.
///
/// The `Debug` output shows only a fingerprint of the secret.
#[derive(Clone)]
pub struct SymbolRemapper {
    /// Forward table: canonical symbol -> wire symbol.
    fwd: Vec<u16>,
    /// Inverse table: wire symbol -> canonical symbol.
    inv: Vec<u16>,
    /// Lowercase hex fingerprint of the secret.
    fingerprint: String,
}
impl std::fmt::Debug for SymbolRemapper {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The tables are omitted: they are large and would reveal the permutation.
        let Self { fingerprint, .. } = self;
        let mut builder = f.debug_struct("SymbolRemapper");
        builder.field("fingerprint", fingerprint);
        builder.finish()
    }
}
impl SymbolRemapper {
/// Builds the forward and inverse permutation tables from `secret`.
///
/// The fingerprint is the lowercase hex of the first four bytes of SHA-256(secret).
/// Construction hashes 8192 SHA-256 blocks and allocates two 65536-entry `u16` tables
/// (128 KiB each), so build once and clone/share rather than rebuilding per frame.
pub fn new(secret: [u8; 32]) -> Self {
let (fwd, inv) = Self::derive_tables(&secret);
let hash = Sha256::digest(secret);
let fingerprint = hex_encode_lower(&hash[..4]);
Self { fwd, inv, fingerprint }
}
/// Parses a 64-character hex secret and builds the remapper.
///
/// # Errors
/// Returns `CryptoError::BadHex` if `hex` is not a valid 32-byte hex string.
pub fn from_hex(hex: &str) -> Result<Self, CryptoError> {
Ok(Self::new(parse_hex32(hex)?))
}
/// Short fingerprint of the secret (8 lowercase hex characters).
pub fn fingerprint(&self) -> &str {
&self.fingerprint
}
/// Maps a canonical symbol to its on-wire value.
///
/// Cannot panic: the table has an entry for every `u16` value.
pub fn remap(&self, symbol: u16) -> u16 {
self.fwd[symbol as usize]
}
/// Maps an on-wire symbol back to its canonical value (inverse of [`Self::remap`]).
pub fn unmap(&self, symbol: u16) -> u16 {
self.inv[symbol as usize]
}
/// Derives a keyed permutation of all 65536 symbols and its inverse.
///
/// Keystream: SHA-256(secret || counter as big-endian u32) for counter = 0, 1, 2, ...,
/// concatenated. Shuffle: Fisher–Yates from the top index down, reading 4 keystream bytes
/// (big-endian u32 `r`) per step and swapping `i` with `j = r % (i + 1)`.
///
/// NOTE(review): `r % (i + 1)` has a tiny modulo bias (on the order of 2^-16). Fixing it
/// would change the permutation, and presumably peers derive the same table from the same
/// secret. Any change here must therefore be versioned, not applied in place.
fn derive_tables(secret: &[u8; 32]) -> (Vec<u16>, Vec<u16>) {
// 65536 * 4 keystream bytes; the shuffle below consumes only 65535 * 4 of them.
let needed = 65536 * 4;
let mut prng: Vec<u8> = Vec::with_capacity(needed);
let mut counter: u32 = 0;
while prng.len() < needed {
let mut h = Sha256::new();
h.update(secret);
h.update(counter.to_be_bytes());
prng.extend_from_slice(&h.finalize());
counter += 1;
}
// Start from the identity permutation.
let mut table: Vec<u16> = (0..=65535u16).collect();
// i walks 65535 down to 1; step k = 65535 - i reads keystream bytes [4k, 4k + 4).
for i in (1..=65535usize).rev() {
let offset = (65535 - i) * 4;
let r = u32::from_be_bytes(prng[offset..offset + 4].try_into().unwrap()) as usize;
let j = r % (i + 1);
table.swap(i, j);
}
// table maps canonical -> wire, so the inverse stores inv[wire] = canonical.
let mut inv = vec![0u16; 65536];
for (i, &v) in table.iter().enumerate() {
inv[v as usize] = i as u16;
}
(table, inv)
}
}
/// Derives per-symbol, per-time-window envelope keys from a master key.
///
/// The `Debug` output shows only the window length, never the master key.
#[derive(Clone)]
pub struct ContextualKeyDeriver {
    /// HKDF input key material.
    master_key: [u8; 32],
    /// Length of one key-rotation window, in seconds.
    window_seconds: u64,
}
impl std::fmt::Debug for ContextualKeyDeriver {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only the non-secret window length is printed.
        let Self { window_seconds, .. } = self;
        f.debug_struct("ContextualKeyDeriver")
            .field("window_seconds", window_seconds)
            .finish()
    }
}
impl ContextualKeyDeriver {
    /// Creates a deriver over `master_key` whose keys rotate every `window_seconds` seconds.
    ///
    /// A `window_seconds` of 0 is accepted and means the derived key never rotates
    /// (see [`Self::derive_key`]).
    pub fn new(master_key: [u8; 32], window_seconds: u64) -> Self {
        Self { master_key, window_seconds }
    }

    /// Parses a 64-character hex master key.
    ///
    /// # Errors
    /// `CryptoError::BadHex` if `hex` is not a valid 32-byte hex string.
    pub fn from_hex(hex: &str, window_seconds: u64) -> Result<Self, CryptoError> {
        Ok(Self::new(parse_hex32(hex)?, window_seconds))
    }

    /// Length of one key-rotation window, in seconds.
    pub fn window_seconds(&self) -> u64 {
        self.window_seconds
    }

    /// Derives the 32-byte key for symbol `(sym_hi, sym_lo)` in the time window that contains
    /// `at_secs`. When `at_secs` is `None`, the current Unix time is used.
    ///
    /// * `window = at_secs / window_seconds` (always 0 when `window_seconds` is 0, so the key
    ///   never rotates instead of panicking on division by zero);
    /// * `salt = SHA-256(sym_hi || sym_lo || window as big-endian u64)`;
    /// * `key = HKDF-SHA256-Expand(HKDF-Extract(salt, master_key), HKDF_INFO, 32)`.
    pub fn derive_key(&self, sym_hi: u8, sym_lo: u8, at_secs: Option<u64>) -> [u8; 32] {
        let now = at_secs.unwrap_or_else(unix_secs);
        // `checked_div` returns None only for a zero divisor; map that to one fixed window.
        let window = now.checked_div(self.window_seconds).unwrap_or(0);
        let mut h = Sha256::new();
        h.update([sym_hi, sym_lo]);
        h.update(window.to_be_bytes());
        let salt: [u8; 32] = h.finalize().into();
        let hk = Hkdf::<Sha256>::new(Some(&salt), &self.master_key);
        let mut okm = [0u8; 32];
        // Expanding 32 bytes from SHA-256 (HashLen = 32) is always within HKDF's limit.
        hk.expand(HKDF_INFO, &mut okm).expect("HKDF expand (L=32 ≤ HashLen=32)");
        okm
    }
}
/// The optional per-frame protection layers applied by [`CryptoContext::encrypt_frame`] and
/// undone by [`CryptoContext::decrypt_frame`].
///
/// A frame is `[sym_hi, sym_lo, payload...]`:
/// * `cipher` seals the payload into an envelope and authenticates the canonical symbol bytes
///   as AAD;
/// * `remapper` replaces the two symbol bytes with their image under a secret permutation;
/// * `key_deriver` replaces the static envelope key with a per-symbol, per-time-window key. It
///   only takes effect together with `cipher`.
#[derive(Debug, Clone)]
pub struct CryptoContext {
    /// Payload envelope cipher; `None` leaves payloads in the clear.
    pub cipher: Option<EnvelopeCipher>,
    /// Symbol permutation; `None` leaves symbol bytes unchanged.
    pub remapper: Option<SymbolRemapper>,
    /// Contextual key derivation for `cipher`; ignored when `cipher` is `None`.
    pub key_deriver: Option<ContextualKeyDeriver>,
}
impl CryptoContext {
    /// Returns `true` when at least one layer changes frames on the wire.
    ///
    /// `key_deriver` alone does not count: without `cipher` it has nothing to key.
    pub fn is_active(&self) -> bool {
        self.cipher.is_some() || self.remapper.is_some()
    }

    /// Protects one outgoing frame.
    ///
    /// The payload is sealed first, using the *canonical* symbol bytes as AAD. The symbol is
    /// remapped afterwards, and [`Self::decrypt_frame`] undoes the steps in the opposite order.
    /// Frames shorter than two bytes have no symbol and are returned unchanged.
    ///
    /// # Errors
    /// Propagates [`CryptoError::Init`] / [`CryptoError::Encrypt`] from the envelope cipher.
    pub fn encrypt_frame(&self, raw: &[u8]) -> Result<Vec<u8>, CryptoError> {
        if raw.len() < 2 {
            return Ok(raw.to_vec());
        }
        let symbol = [raw[0], raw[1]];
        let body = &raw[2..];

        // `cipher_for_symbol` yields `None` exactly when no cipher is configured.
        let payload = match self.cipher_for_symbol(symbol[0], symbol[1], None)? {
            Some(cipher) => cipher.encrypt(body, &symbol)?,
            None => body.to_vec(),
        };

        let wire = match &self.remapper {
            Some(remapper) => remapper.remap(u16::from_be_bytes(symbol)).to_be_bytes(),
            None => symbol,
        };
        Ok(Self::assemble(wire, &payload))
    }

    /// Reverses [`Self::encrypt_frame`] for one incoming frame.
    ///
    /// The symbol is un-remapped first so that the canonical bytes can serve as AAD. With a
    /// `key_deriver`, the current time window's key is tried first. On an authentication
    /// failure, the previous window's key is tried.
    ///
    /// # Errors
    /// Any [`CryptoError`] from opening the envelope, including [`CryptoError::AuthFailed`]
    /// when no candidate key authenticates the payload.
    pub fn decrypt_frame(&self, raw: &[u8]) -> Result<Vec<u8>, CryptoError> {
        if raw.len() < 2 {
            return Ok(raw.to_vec());
        }
        let wire = [raw[0], raw[1]];
        let symbol = match &self.remapper {
            Some(remapper) => remapper.unmap(u16::from_be_bytes(wire)).to_be_bytes(),
            None => wire,
        };
        let body = &raw[2..];

        let payload = match &self.cipher {
            Some(base) if body.starts_with(&ENVELOPE_MARKER) => {
                self.open_envelope(base, symbol, body)?
            }
            // NOTE(review): with a cipher configured, a payload lacking the envelope marker is
            // still accepted as plaintext, presumably for compatibility with unencrypted
            // senders. It also lets anyone inject unauthenticated frames; consider rejecting
            // it once all peers encrypt.
            _ => body.to_vec(),
        };
        Ok(Self::assemble(symbol, &payload))
    }

    /// Authenticates and decrypts `body` for the canonical `symbol` using `base`, or its
    /// contextual variants when a `key_deriver` is configured.
    fn open_envelope(
        &self,
        base: &EnvelopeCipher,
        symbol: [u8; 2],
        body: &[u8],
    ) -> Result<Vec<u8>, CryptoError> {
        let deriver = match &self.key_deriver {
            Some(deriver) => deriver,
            None => return base.decrypt(body, &symbol),
        };

        let attempt = |at: u64| -> Result<Vec<u8>, CryptoError> {
            let cipher = self
                .cipher_for_symbol(symbol[0], symbol[1], Some(at))?
                .unwrap_or_else(|| base.clone());
            cipher.decrypt(body, &symbol)
        };

        let now = unix_secs();
        match attempt(now) {
            // Only a key mismatch can be cured by the previous window's key. Length, marker
            // and algorithm errors do not depend on the key, so they are returned as-is. If the
            // retry also fails, its error (normally AuthFailed) is returned, never plaintext.
            Err(e) if e.is_auth_failed_sentinel() => {
                attempt(now.saturating_sub(deriver.window_seconds()))
            }
            other => other,
        }
    }

    /// Concatenates the two symbol bytes and the payload into one frame.
    fn assemble(symbol: [u8; 2], payload: &[u8]) -> Vec<u8> {
        let mut out = Vec::with_capacity(symbol.len() + payload.len());
        out.extend_from_slice(&symbol);
        out.extend_from_slice(payload);
        out
    }

    /// Returns the envelope cipher to use for `(sym_hi, sym_lo)` at `at_secs`.
    ///
    /// * `Ok(None)` when no cipher is configured;
    /// * the base cipher (cloned) when there is no `key_deriver`;
    /// * otherwise a cipher keyed with the contextual key for that symbol and time window,
    ///   using the base cipher's algorithm.
    fn cipher_for_symbol(
        &self,
        sym_hi: u8,
        sym_lo: u8,
        at_secs: Option<u64>,
    ) -> Result<Option<EnvelopeCipher>, CryptoError> {
        let base = match &self.cipher {
            None => return Ok(None),
            Some(c) => c,
        };
        let deriver = match &self.key_deriver {
            None => return Ok(Some(base.clone())),
            Some(d) => d,
        };
        let frame_key = deriver.derive_key(sym_hi, sym_lo, at_secs);
        Ok(Some(EnvelopeCipher::new(frame_key, base.algo())))
    }

    /// Builds a context from environment variables.
    ///
    /// * `THETA_ENVELOPE_KEY`: 64-hex-char key; enables `cipher`.
    /// * `THETA_ENVELOPE_ALGO`: any value containing "chacha" (case-insensitive) selects
    ///   ChaCha20-Poly1305. Anything else, or unset, selects AES-256-GCM.
    /// * `THETA_CONTEXTUAL_SALT`: "1" / "true" / "yes" (case-insensitive, surrounding
    ///   whitespace ignored) enables contextual key derivation from the envelope key. Ignored
    ///   without `THETA_ENVELOPE_KEY`.
    /// * `THETA_CONTEXTUAL_SALT_WINDOW`: window length in seconds; 3600 if unset or unparseable.
    /// * `THETA_SYMBOL_REMAP_KEY`: 64-hex-char secret; enables `remapper`.
    ///
    /// # Errors
    /// `CryptoError::BadHex` if a provided key is malformed.
    pub fn from_env() -> Result<Self, CryptoError> {
        let mut cipher: Option<EnvelopeCipher> = None;
        let mut key_deriver: Option<ContextualKeyDeriver> = None;

        if let Ok(hex) = std::env::var("THETA_ENVELOPE_KEY") {
            let algo_str = std::env::var("THETA_ENVELOPE_ALGO")
                .unwrap_or_else(|_| "aes256gcm".into());
            let algo = if algo_str.to_ascii_lowercase().contains("chacha") {
                AlgoId::ChaCha20Poly1305
            } else {
                AlgoId::Aes256Gcm
            };
            cipher = Some(EnvelopeCipher::from_hex(&hex, algo)?);

            let salt_flag = std::env::var("THETA_CONTEXTUAL_SALT").unwrap_or_default();
            if matches!(salt_flag.trim().to_lowercase().as_str(), "1" | "true" | "yes") {
                let window: u64 = std::env::var("THETA_CONTEXTUAL_SALT_WINDOW")
                    .ok()
                    .and_then(|s| s.trim().parse().ok())
                    .unwrap_or(3600);
                key_deriver = Some(ContextualKeyDeriver::from_hex(&hex, window)?);
            }
        }

        let remapper = match std::env::var("THETA_SYMBOL_REMAP_KEY") {
            Ok(hex) => Some(SymbolRemapper::from_hex(&hex)?),
            Err(_) => None,
        };

        Ok(Self { cipher, remapper, key_deriver })
    }

    /// Starts a [`CryptoContextBuilder`] with every layer disabled.
    pub fn builder() -> CryptoContextBuilder {
        CryptoContextBuilder::default()
    }
}
/// Fluent builder for [`CryptoContext`]; every layer starts disabled.
#[derive(Default)]
pub struct CryptoContextBuilder {
    cipher: Option<EnvelopeCipher>,
    remapper: Option<SymbolRemapper>,
    key_deriver: Option<ContextualKeyDeriver>,
}
impl CryptoContextBuilder {
    /// Enables payload encryption with `key` under `algo`, replacing any earlier setting.
    pub fn envelope_key(self, key: [u8; 32], algo: AlgoId) -> Self {
        Self {
            cipher: Some(EnvelopeCipher::new(key, algo)),
            ..self
        }
    }

    /// Enables symbol remapping derived from `secret`, replacing any earlier setting.
    pub fn remap_key(self, secret: [u8; 32]) -> Self {
        Self {
            remapper: Some(SymbolRemapper::new(secret)),
            ..self
        }
    }

    /// Enables contextual key derivation from `master_key` with `window_seconds` windows.
    pub fn hkdf_window(self, master_key: [u8; 32], window_seconds: u64) -> Self {
        Self {
            key_deriver: Some(ContextualKeyDeriver::new(master_key, window_seconds)),
            ..self
        }
    }

    /// Finishes the builder, moving every configured layer into the context.
    pub fn build(self) -> CryptoContext {
        let Self { cipher, remapper, key_deriver } = self;
        CryptoContext { cipher, remapper, key_deriver }
    }
}
/// Errors from key parsing, envelope sealing/opening and frame processing.
#[derive(Debug, thiserror::Error)]
pub enum CryptoError {
/// A key string was not exactly 64 valid hex characters; carries a description.
#[error("Invalid hex key: {0}")]
BadHex(String),
/// The AEAD could not be constructed from the key bytes.
#[error("Cipher initialisation failed: {0}")]
Init(String),
/// The AEAD reported an error while sealing.
#[error("Encryption failed: {0}")]
Encrypt(String),
/// Opening failed: wrong key, wrong AAD, or modified ciphertext.
#[error("Authentication failed (wrong key or tampered data)")]
AuthFailed,
/// The envelope is shorter than the minimum envelope length; carries the actual length.
#[error("Encrypted region too short: {0} bytes")]
TooShort(usize),
/// The envelope does not start with the expected two-byte marker.
#[error("Missing envelope marker")]
BadMarker,
/// The envelope's algorithm byte differs from the configured cipher's algorithm.
#[error("Algo mismatch: expected 0x{expected:02X}, got 0x{got:02X}")]
AlgoMismatch { expected: u8, got: u8 },
}
impl CryptoError {
/// Returns `true` only for [`CryptoError::AuthFailed`].
///
/// Called from `CryptoContext`'s frame decryption. NOTE(review): `AuthFailed` is a genuine
/// failure. Callers must not treat it as a "no error" marker, or unauthenticated data can slip
/// through as success.
fn is_auth_failed_sentinel(&self) -> bool {
matches!(self, CryptoError::AuthFailed)
}
}
/// Current Unix time in whole seconds.
///
/// If the system clock reads earlier than the Unix epoch, this returns 0 instead of panicking.
fn unix_secs() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
/// Parses exactly 64 hex digits (either case) into a 32-byte key.
///
/// Leading and trailing whitespace is ignored, e.g. a newline left in an environment variable.
/// Every remaining character must be `0-9`, `a-f` or `A-F`. Digits are validated one nibble at
/// a time because `u8::from_str_radix` would also accept a leading `+` sign.
///
/// # Errors
/// [`CryptoError::BadHex`] if the trimmed input is not 64 bytes long or contains a non-hex digit.
fn parse_hex32(hex: &str) -> Result<[u8; 32], CryptoError> {
    let digits = hex.trim().as_bytes();
    if digits.len() != 64 {
        return Err(CryptoError::BadHex(format!(
            "expected 64 hex chars (32 bytes), got {}",
            digits.len()
        )));
    }

    // Decodes the hex digit at `pos`; non-ASCII bytes never map to a hex digit.
    let nibble = |pos: usize| -> Result<u8, CryptoError> {
        (digits[pos] as char)
            .to_digit(16)
            .map(|d| d as u8)
            .ok_or_else(|| CryptoError::BadHex(format!("invalid hex digit at position {pos}")))
    };

    let mut out = [0u8; 32];
    for (i, byte) in out.iter_mut().enumerate() {
        *byte = (nibble(2 * i)? << 4) | nibble(2 * i + 1)?;
    }
    Ok(out)
}
/// Encodes `bytes` as lowercase hex, two characters per byte.
fn hex_encode_lower(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut encoded = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        encoded.push(DIGITS[usize::from(byte >> 4)] as char);
        encoded.push(DIGITS[usize::from(byte & 0x0f)] as char);
    }
    encoded
}
// Unit tests for the envelope cipher, symbol remapper, key deriver and the combined context.
// NOTE(review): every CryptoContext test below exercises a successful path only; there is no
// test that decrypt_frame rejects a frame sealed under a different key or a tampered payload.
#[cfg(test)]
mod tests {
use super::*;
// Fixed key material so the tests are deterministic apart from the random nonces.
fn test_key() -> [u8; 32] { [0x42u8; 32] }
fn test_secret() -> [u8; 32] { [0x7Fu8; 32] }
// AES-256-GCM: sealing then opening with the same AAD returns the plaintext.
#[test]
fn envelope_cipher_aes_round_trip() {
let c = EnvelopeCipher::new(test_key(), AlgoId::Aes256Gcm);
let aad = [0xFC, 0x60];
let plain = b"hello world";
let enc = c.encrypt(plain, &aad).unwrap();
let dec = c.decrypt(&enc, &aad).unwrap();
assert_eq!(dec, plain);
}
// ChaCha20-Poly1305 round trip with empty AAD.
#[test]
fn envelope_cipher_chacha_round_trip() {
let c = EnvelopeCipher::new(test_key(), AlgoId::ChaCha20Poly1305);
let plain = b"chacha test";
let enc = c.encrypt(plain, b"").unwrap();
let dec = c.decrypt(&enc, b"").unwrap();
assert_eq!(dec, plain);
}
// The AAD is authenticated: opening with a different AAD must fail.
#[test]
fn envelope_cipher_aad_binding() {
let c = EnvelopeCipher::new(test_key(), AlgoId::Aes256Gcm);
let enc = c.encrypt(b"data", &[0xFC, 0x60]).unwrap();
assert!(c.decrypt(&enc, &[0xFC, 0x61]).is_err());
}
// remap and unmap are mutual inverses on a sample that includes both ends of the u16 range.
#[test]
fn symbol_remapper_bijection() {
let r = SymbolRemapper::new(test_secret());
for sym in [0u16, 1, 0xFC60, 0x00FF, 0xFFFF] {
assert_eq!(r.unmap(r.remap(sym)), sym);
assert_eq!(r.remap(r.unmap(sym)), sym);
}
}
// Injectivity check, limited to the first 256 symbols to keep the test cheap.
#[test]
fn symbol_remapper_is_permutation() {
let r = SymbolRemapper::new(test_secret());
let mapped: Vec<u16> = (0u16..=255).map(|s| r.remap(s)).collect();
let mut sorted = mapped.clone();
sorted.sort_unstable();
sorted.dedup();
assert_eq!(sorted.len(), 256); }
// Same inputs (symbol + explicit timestamp) always produce the same key.
#[test]
fn contextual_key_deriver_deterministic() {
let d = ContextualKeyDeriver::new(test_key(), 3600);
let k1 = d.derive_key(0xFC, 0x60, Some(1_000_000));
let k2 = d.derive_key(0xFC, 0x60, Some(1_000_000));
assert_eq!(k1, k2);
}
// Timestamps 0 and 3600 fall in adjacent windows and must yield different keys.
#[test]
fn contextual_key_deriver_window_rotation() {
let d = ContextualKeyDeriver::new(test_key(), 3600);
let k1 = d.derive_key(0xFC, 0x60, Some(0));
let k2 = d.derive_key(0xFC, 0x60, Some(3600));
assert_ne!(k1, k2, "different windows must yield different keys");
}
// Keys are bound to the symbol: changing only sym_lo changes the key.
#[test]
fn contextual_key_deriver_symbol_binding() {
let d = ContextualKeyDeriver::new(test_key(), 3600);
let k1 = d.derive_key(0xFC, 0x60, Some(0));
let k2 = d.derive_key(0xFC, 0x61, Some(0));
assert_ne!(k1, k2, "different symbols must yield different keys");
}
// All three layers together: the wire frame differs from the input and decrypts back to it.
#[test]
fn crypto_context_all_layers_round_trip() {
let ctx = CryptoContext::builder()
.envelope_key(test_key(), AlgoId::Aes256Gcm)
.remap_key(test_secret())
.hkdf_window(test_key(), 3600)
.build();
let frame = vec![0xFC, 0x60, b'{', b'"', b'v', b'"', b':', b'1', b'}'];
let enc = ctx.encrypt_frame(&frame).unwrap();
assert_ne!(enc, frame);
let dec = ctx.decrypt_frame(&enc).unwrap();
assert_eq!(dec, frame);
}
// Envelope encryption alone, with no remapping and no key derivation.
#[test]
fn crypto_context_cipher_only() {
let ctx = CryptoContext::builder()
.envelope_key(test_key(), AlgoId::Aes256Gcm)
.build();
let frame = vec![0xF2, 0x10, b'{', b'}'];
let enc = ctx.encrypt_frame(&frame).unwrap();
let dec = ctx.decrypt_frame(&enc).unwrap();
assert_eq!(dec, frame);
}
// Remapping alone: the wire symbol equals SymbolRemapper::remap of the canonical symbol.
#[test]
fn crypto_context_remapper_only() {
let ctx = CryptoContext::builder()
.remap_key(test_secret())
.build();
let frame = vec![0xFC, 0x60, b'{', b'}'];
let enc = ctx.encrypt_frame(&frame).unwrap();
let r = SymbolRemapper::new(test_secret());
assert_eq!(enc[0], (r.remap(0xFC60) >> 8) as u8);
assert_eq!(enc[1], (r.remap(0xFC60) & 0xFF) as u8);
let dec = ctx.decrypt_frame(&enc).unwrap();
assert_eq!(dec, frame);
}
// With no layers configured, frames pass through unchanged in both directions.
#[test]
fn crypto_context_inactive_passthrough() {
let ctx = CryptoContext { cipher: None, remapper: None, key_deriver: None };
let frame = vec![0xFC, 0x60, b'{', b'}'];
assert_eq!(ctx.encrypt_frame(&frame).unwrap(), frame);
assert_eq!(ctx.decrypt_frame(&frame).unwrap(), frame);
}
}