#![allow(dead_code)]
/// AES key-wrap algorithm variant, distinguished by AES key length.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum KwAlgorithm {
    Aes128,
    Aes192,
    Aes256,
}
impl KwAlgorithm {
    /// Every variant, used for reverse lookup from the wire byte.
    const ALL: [Self; 3] = [Self::Aes128, Self::Aes192, Self::Aes256];

    /// Single source of per-variant data: `(key bytes, display name, wire byte)`.
    fn spec(self) -> (usize, &'static str, u8) {
        match self {
            Self::Aes128 => (16, "AES-128", 2),
            Self::Aes192 => (24, "AES-192", 3),
            Self::Aes256 => (32, "AES-256", 4),
        }
    }

    /// Length of the AES key in bytes: 16, 24 or 32.
    pub fn key_size_bytes(&self) -> usize {
        self.spec().0
    }

    /// Human-readable algorithm name, e.g. `"AES-128"`.
    pub fn name(&self) -> &'static str {
        self.spec().1
    }

    /// Byte code written into the cipher field of serialized key material.
    fn as_u8(self) -> u8 {
        self.spec().2
    }

    /// Inverse of `as_u8`; returns `None` for bytes that name no variant.
    fn from_u8(v: u8) -> Option<Self> {
        Self::ALL.iter().copied().find(|alg| alg.as_u8() == v)
    }
}
/// Key-material record, as built by `KeyMaterial::new` and serialized by
/// `KeyMaterial::to_bytes`. Field names presumably mirror the SRT key-material
/// message layout — NOTE(review): confirm against the SRT specification.
#[derive(Debug, Clone)]
pub struct KeyMaterial {
    /// Format version; `is_valid` requires 1.
    pub version: u8,
    /// Packet-type byte; `new` sets 2.
    pub pt: u8,
    /// Signature word; `is_valid` requires 0x2029. Serialized big-endian.
    pub sign: u16,
    /// Key-flags byte; `new` sets 3.
    pub kk: u8,
    /// Key-encryption-key index; `new` sets 0. Serialized as 4 big-endian bytes.
    pub keki: u32,
    /// Key-wrap algorithm; serialized as a single byte code.
    pub cipher: KwAlgorithm,
    /// Authentication byte; `new` sets 0.
    pub auth: u8,
    /// Stream-encapsulation byte; `new` sets 2.
    pub se: u8,
    /// Salt bytes; serialized with a one-byte length prefix.
    pub salt: Vec<u8>,
    /// Wrapped session key; serialized with a two-byte big-endian length prefix.
    pub wrapped_key: Vec<u8>,
}
/// Produces `count` deterministic pseudo-random bytes from `seed`.
///
/// Each step advances a 64-bit linear congruential generator (Knuth's MMIX
/// constants) and emits bits 33..=40 of the new state. The same seed always
/// yields the same bytes; this is NOT cryptographically secure.
fn lcg_bytes(seed: u64, count: usize) -> Vec<u8> {
    const MULTIPLIER: u64 = 6_364_136_223_846_793_005;
    const INCREMENT: u64 = 1_442_695_040_888_963_407;
    (0..count)
        .scan(seed, |state, _| {
            *state = state.wrapping_mul(MULTIPLIER).wrapping_add(INCREMENT);
            Some((*state >> 33) as u8)
        })
        .collect()
}
/// Simulated key wrap; this is NOT AES key wrap.
///
/// Returns eight `0xA6` header bytes followed by `plaintext` XORed with `key`
/// repeated cyclically. The output is always `plaintext.len() + 8` bytes. The
/// header value presumably mimics the RFC 3394 default IV.
///
/// # Panics
///
/// Panics if `key` is empty while `plaintext` is non-empty, because the
/// cyclic index takes the remainder modulo `key.len()`.
fn sim_wrap(key: &[u8], plaintext: &[u8]) -> Vec<u8> {
    const HEADER: [u8; 8] = [0xA6; 8];
    let masked = plaintext
        .iter()
        .enumerate()
        .map(|(idx, byte)| byte ^ key[idx % key.len()]);
    HEADER.iter().copied().chain(masked).collect()
}
impl KeyMaterial {
    /// Bytes preceding the salt in the serialized form:
    /// version, pt, sign (2), kk, keki (4), cipher, auth, se, salt length.
    const FIXED_HEADER_LEN: usize = 13;

    /// Builds simulated key material for `algorithm`.
    ///
    /// The 16-byte salt and the session key (sized by `algorithm`) are derived
    /// deterministically from `seed` with `lcg_bytes`. The session key is then
    /// "wrapped" by `sim_wrap` using an all-zero key. This is a simulation, not
    /// real cryptography: the same `seed` always yields the same bytes.
    pub fn new(algorithm: KwAlgorithm, seed: u64) -> Self {
        let key_bytes = algorithm.key_size_bytes();
        let salt = lcg_bytes(seed ^ 0xDEAD_BEEF_CAFE_0001, 16);
        let sek = lcg_bytes(seed ^ 0x1234_5678_9ABC_DEF0, key_bytes);
        let wrapped_key = sim_wrap(&vec![0u8; key_bytes], &sek);
        Self {
            version: 1,
            pt: 2,
            sign: 0x2029,
            kk: 3,
            keki: 0,
            cipher: algorithm,
            auth: 0,
            se: 2,
            salt,
            wrapped_key,
        }
    }

    /// Serializes this record into its wire layout.
    ///
    /// Layout (multi-byte integers big-endian):
    /// `version | pt | sign(2) | kk | keki(4) | cipher | auth | se |
    /// salt_len(1) | salt | wrapped_key_len(2) | wrapped_key`.
    ///
    /// # Panics
    ///
    /// Panics if `salt` is longer than 255 bytes or `wrapped_key` is longer
    /// than 65535 bytes, since their length prefixes cannot represent that.
    /// Use [`KeyMaterial::try_to_bytes`] to get an error instead.
    pub fn to_bytes(&self) -> Vec<u8> {
        match self.try_to_bytes() {
            Ok(bytes) => bytes,
            Err(err) => panic!("KeyMaterial::to_bytes: {}", err),
        }
    }

    /// Fallible counterpart of [`KeyMaterial::to_bytes`], with the same layout.
    ///
    /// # Errors
    ///
    /// Returns an error if `salt` exceeds 255 bytes or `wrapped_key` exceeds
    /// 65535 bytes. Encoding such a length into its prefix would truncate it and
    /// produce a frame that `from_bytes` misparses.
    pub fn try_to_bytes(&self) -> Result<Vec<u8>, String> {
        let salt_len = self.salt.len();
        if salt_len > usize::from(u8::MAX) {
            return Err(format!(
                "KeyMaterial salt too long: {salt_len} bytes (max {})",
                u8::MAX
            ));
        }
        let wk_len = self.wrapped_key.len();
        if wk_len > usize::from(u16::MAX) {
            return Err(format!(
                "KeyMaterial wrapped_key too long: {wk_len} bytes (max {})",
                u16::MAX
            ));
        }
        let mut out = Vec::with_capacity(Self::FIXED_HEADER_LEN + salt_len + 2 + wk_len);
        out.push(self.version);
        out.push(self.pt);
        out.extend_from_slice(&self.sign.to_be_bytes());
        out.push(self.kk);
        out.extend_from_slice(&self.keki.to_be_bytes());
        out.push(self.cipher.as_u8());
        out.push(self.auth);
        out.push(self.se);
        // Both lengths were range-checked above, so these narrowing casts are lossless.
        out.push(salt_len as u8);
        out.extend_from_slice(&self.salt);
        out.extend_from_slice(&(wk_len as u16).to_be_bytes());
        out.extend_from_slice(&self.wrapped_key);
        Ok(out)
    }

    /// Parses a record in the layout produced by `to_bytes`.
    ///
    /// Bytes after the wrapped key are ignored. This lets callers pass a
    /// payload that still carries trailing padding, such as the 4-byte
    /// alignment added by `HsExtension::to_bytes`.
    ///
    /// # Errors
    ///
    /// Returns an error if the input is shorter than the minimum record
    /// (15 bytes), if the cipher byte is unknown, or if the salt or wrapped key
    /// extends past the end of `data`.
    pub fn from_bytes(data: &[u8]) -> Result<Self, String> {
        // Smallest valid record: fixed header + empty salt + 2-byte key length.
        if data.len() < Self::FIXED_HEADER_LEN + 2 {
            return Err("KeyMaterial data too short".to_string());
        }
        let version = data[0];
        let pt = data[1];
        let sign = u16::from_be_bytes([data[2], data[3]]);
        let kk = data[4];
        let keki = u32::from_be_bytes([data[5], data[6], data[7], data[8]]);
        let cipher_byte = data[9];
        let cipher = KwAlgorithm::from_u8(cipher_byte)
            .ok_or_else(|| format!("Unknown cipher byte: {cipher_byte}"))?;
        let auth = data[10];
        let se = data[11];
        let salt_len = usize::from(data[12]);
        let wk_offset = Self::FIXED_HEADER_LEN + salt_len;
        if data.len() < wk_offset + 2 {
            return Err("KeyMaterial truncated at salt".to_string());
        }
        let salt = data[Self::FIXED_HEADER_LEN..wk_offset].to_vec();
        let wk_len = usize::from(u16::from_be_bytes([data[wk_offset], data[wk_offset + 1]]));
        let wk_start = wk_offset + 2;
        let wrapped_key = data
            .get(wk_start..wk_start + wk_len)
            .ok_or_else(|| "KeyMaterial truncated at wrapped_key".to_string())?
            .to_vec();
        Ok(Self {
            version,
            pt,
            sign,
            kk,
            keki,
            cipher,
            auth,
            se,
            salt,
            wrapped_key,
        })
    }

    /// Returns `true` when the signature is 0x2029 and the version is 1.
    pub fn is_valid(&self) -> bool {
        self.sign == 0x2029 && self.version == 1
    }

    /// Key length of the selected algorithm, in bits.
    pub fn key_size_bits(&self) -> usize {
        self.cipher.key_size_bytes() * 8
    }
}
/// Handshake extension type codes; each discriminant is the on-wire value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HsExtType {
    HsReq = 1,
    HsRsp = 2,
    KmReq = 3,
    KmRsp = 4,
    Sid = 5,
    Group = 6,
}
impl HsExtType {
    /// Every variant, used for reverse lookup from a wire code.
    const ALL: [Self; 6] = [
        Self::HsReq,
        Self::HsRsp,
        Self::KmReq,
        Self::KmRsp,
        Self::Sid,
        Self::Group,
    ];

    /// Maps a wire code back to its variant; `None` for unassigned codes.
    pub fn from_u16(v: u16) -> Option<Self> {
        Self::ALL.iter().copied().find(|kind| kind.as_u16() == v)
    }

    /// Wire code of this extension type (the enum discriminant).
    pub fn as_u16(&self) -> u16 {
        let kind = *self;
        kind as u16
    }
}
/// One handshake extension: a type code, a payload size in 4-byte words, and
/// the payload bytes.
#[derive(Debug, Clone)]
pub struct HsExtension {
    /// Extension type, serialized as a big-endian u16.
    pub ext_type: HsExtType,
    /// Payload length in 4-byte words (rounded up). `to_bytes` zero-pads the
    /// payload up to `ext_size * 4` bytes.
    pub ext_size: u16,
    /// Payload bytes. After `from_bytes` this includes any trailing padding.
    pub data: Vec<u8>,
}
impl HsExtension {
pub fn key_material(km: &KeyMaterial) -> Self {
let data = km.to_bytes();
let words = data.len().div_ceil(4) as u16;
Self {
ext_type: HsExtType::KmReq,
ext_size: words,
data,
}
}
pub fn stream_id(sid: &str) -> Self {
let data = sid.as_bytes().to_vec();
let words = data.len().div_ceil(4) as u16;
Self {
ext_type: HsExtType::Sid,
ext_size: words,
data,
}
}
pub fn to_bytes(&self) -> Vec<u8> {
let padded_len = self.ext_size as usize * 4;
let mut out = Vec::with_capacity(4 + padded_len);
out.push((self.ext_type.as_u16() >> 8) as u8);
out.push((self.ext_type.as_u16() & 0xFF) as u8);
out.push((self.ext_size >> 8) as u8);
out.push((self.ext_size & 0xFF) as u8);
out.extend_from_slice(&self.data);
while out.len() < 4 + padded_len {
out.push(0);
}
out
}
pub fn from_bytes(data: &[u8]) -> Result<Self, String> {
if data.len() < 4 {
return Err("HsExtension too short".to_string());
}
let type_code = (u16::from(data[0]) << 8) | u16::from(data[1]);
let ext_type = HsExtType::from_u16(type_code)
.ok_or_else(|| format!("Unknown HsExtType: {type_code}"))?;
let ext_size = (u16::from(data[2]) << 8) | u16::from(data[3]);
let payload_len = ext_size as usize * 4;
if data.len() < 4 + payload_len {
return Err("HsExtension data truncated".to_string());
}
let payload = data[4..4 + payload_len].to_vec();
Ok(Self {
ext_type,
ext_size,
data: payload,
})
}
}
/// Lifecycle states of an `EncryptionSession`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EncryptionState {
    /// Encryption disabled; the state of `EncryptionSession::no_encryption`.
    NoEncryption,
    /// Initial state of `EncryptionSession::new`, before any key material exists.
    PendingKeyMaterial,
    /// Set by `initiate` after generating local key material.
    KeyMaterialSent,
    /// Set when `apply_peer_km` accepts peer key material.
    Active,
    /// Set by `rotate`; still counts as active for `is_active`.
    Rotating,
    /// Set when `apply_peer_km` rejects invalid peer key material.
    Failed,
}
/// Key-exchange state for one side of an encrypted connection.
///
/// Tracks locally generated key material, the key material accepted from the
/// peer, and how many packets have been recorded since keys were last
/// installed or rotated.
pub struct EncryptionSession {
    /// Current position in the key-exchange lifecycle.
    state: EncryptionState,
    /// Algorithm used for key material this side generates.
    algorithm: KwAlgorithm,
    /// Key material most recently accepted from the peer.
    current_km: Option<KeyMaterial>,
    /// Key material most recently generated by `initiate` or `rotate`.
    pending_km: Option<KeyMaterial>,
    /// Packet count at which `should_rotate` starts reporting `true`.
    key_rotation_interval_packets: u64,
    /// Packets recorded since the last key install or rotation.
    packets_since_rotation: u64,
}
impl EncryptionSession {
    /// Creates a session that will use `algorithm`. It starts in
    /// `PendingKeyMaterial` and asks for rotation every 8192 recorded packets.
    pub fn new(algorithm: KwAlgorithm) -> Self {
        Self {
            state: EncryptionState::PendingKeyMaterial,
            algorithm,
            current_km: None,
            pending_km: None,
            key_rotation_interval_packets: 8_192,
            packets_since_rotation: 0,
        }
    }

    /// Creates a session with encryption disabled.
    ///
    /// Such a session never becomes active: `initiate` and `rotate` return
    /// `None`, and `apply_peer_km` returns an error. The algorithm stored here
    /// is only a placeholder.
    pub fn no_encryption() -> Self {
        Self {
            state: EncryptionState::NoEncryption,
            algorithm: KwAlgorithm::Aes128,
            current_km: None,
            pending_km: None,
            key_rotation_interval_packets: u64::MAX,
            packets_since_rotation: 0,
        }
    }

    /// Current lifecycle state.
    pub fn state(&self) -> EncryptionState {
        self.state
    }

    /// `true` while keys are installed (`Active`) or being rotated (`Rotating`).
    pub fn is_active(&self) -> bool {
        matches!(
            self.state,
            EncryptionState::Active | EncryptionState::Rotating
        )
    }

    /// Generates local key material from `seed`, stores a copy as pending, and
    /// moves to `KeyMaterialSent`. Returns `None` if encryption is disabled.
    // NOTE(review): this also runs from `Active`, which drops the session out
    // of the active states; confirm that re-initiation is intended.
    pub fn initiate(&mut self, seed: u64) -> Option<KeyMaterial> {
        if self.state == EncryptionState::NoEncryption {
            return None;
        }
        let km = KeyMaterial::new(self.algorithm, seed);
        self.pending_km = Some(km.clone());
        self.state = EncryptionState::KeyMaterialSent;
        Some(km)
    }

    /// Installs key material received from the peer and activates the session.
    ///
    /// On success the packet counter is reset and the state becomes `Active`.
    ///
    /// # Errors
    ///
    /// - If the session was created by `no_encryption()`, the call fails and
    ///   the state stays `NoEncryption`.
    /// - If `km` fails `KeyMaterial::is_valid`, the call fails and the state
    ///   becomes `Failed`.
    pub fn apply_peer_km(&mut self, km: KeyMaterial) -> Result<(), String> {
        // Mirror `initiate`: a disabled session carries a placeholder algorithm
        // and a u64::MAX rotation interval. Letting peer key material switch it
        // to `Active` would run encryption whose keys never rotate.
        if self.state == EncryptionState::NoEncryption {
            return Err(
                "Encryption is disabled for this session; peer key material rejected"
                    .to_string(),
            );
        }
        if !km.is_valid() {
            self.state = EncryptionState::Failed;
            return Err("Peer key material is invalid (bad signature/version)".to_string());
        }
        self.current_km = Some(km);
        self.packets_since_rotation = 0;
        self.state = EncryptionState::Active;
        Ok(())
    }

    /// `true` once an active session has recorded at least the rotation
    /// interval's worth of packets.
    pub fn should_rotate(&self) -> bool {
        self.is_active() && self.packets_since_rotation >= self.key_rotation_interval_packets
    }

    /// Counts one packet toward the rotation interval.
    pub fn record_packet(&mut self) {
        // Saturate rather than overflow. Matters for the u64::MAX interval of
        // disabled sessions, which also count packets.
        self.packets_since_rotation = self.packets_since_rotation.saturating_add(1);
    }

    /// Starts a key rotation: generates new local key material from `seed`,
    /// stores it as pending, resets the packet counter, and moves to
    /// `Rotating`. Returns `None` unless the session is active.
    // NOTE(review): nothing in this impl moves `pending_km` into `current_km`
    // or leaves `Rotating`; confirm where rotation completes.
    pub fn rotate(&mut self, seed: u64) -> Option<KeyMaterial> {
        if !self.is_active() {
            return None;
        }
        let km = KeyMaterial::new(self.algorithm, seed);
        self.pending_km = Some(km.clone());
        self.packets_since_rotation = 0;
        self.state = EncryptionState::Rotating;
        Some(km)
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for the key-material codec, handshake extensions and the
    // encryption-session state machine.
    use super::*;

    #[test]
    fn test_kw_algorithm_key_size() {
        assert_eq!(KwAlgorithm::Aes128.key_size_bytes(), 16);
        assert_eq!(KwAlgorithm::Aes192.key_size_bytes(), 24);
        assert_eq!(KwAlgorithm::Aes256.key_size_bytes(), 32);
    }

    #[test]
    fn test_key_material_new_is_valid() {
        let km = KeyMaterial::new(KwAlgorithm::Aes128, 42);
        assert!(km.is_valid(), "KeyMaterial::new should produce valid KM");
    }

    #[test]
    fn test_key_material_roundtrip() {
        let km = KeyMaterial::new(KwAlgorithm::Aes256, 99);
        let bits_before = km.key_size_bits();
        let bytes = km.to_bytes();
        let km2 = KeyMaterial::from_bytes(&bytes).expect("from_bytes should succeed");
        assert_eq!(km2.key_size_bits(), bits_before);
        assert!(km2.is_valid());
    }

    #[test]
    fn test_key_material_roundtrip_preserves_all_fields() {
        let km = KeyMaterial::new(KwAlgorithm::Aes192, 5);
        let km2 = KeyMaterial::from_bytes(&km.to_bytes()).expect("roundtrip should succeed");
        assert_eq!(km2.version, km.version);
        assert_eq!(km2.pt, km.pt);
        assert_eq!(km2.sign, km.sign);
        assert_eq!(km2.kk, km.kk);
        assert_eq!(km2.keki, km.keki);
        assert_eq!(km2.cipher, km.cipher);
        assert_eq!(km2.auth, km.auth);
        assert_eq!(km2.se, km.se);
        assert_eq!(km2.salt, km.salt);
        assert_eq!(km2.wrapped_key, km.wrapped_key);
    }

    #[test]
    fn test_key_material_new_sizes() {
        let km = KeyMaterial::new(KwAlgorithm::Aes192, 5);
        assert_eq!(km.salt.len(), 16);
        // 8-byte simulated wrap header + 24-byte session key.
        assert_eq!(km.wrapped_key.len(), 8 + 24);
    }

    #[test]
    fn test_key_material_from_bytes_errors() {
        assert_eq!(
            KeyMaterial::from_bytes(&[0u8; 14]).unwrap_err(),
            "KeyMaterial data too short"
        );
        let bytes = KeyMaterial::new(KwAlgorithm::Aes128, 1).to_bytes();
        let mut bad_cipher = bytes.clone();
        bad_cipher[9] = 0xFF;
        assert_eq!(
            KeyMaterial::from_bytes(&bad_cipher).unwrap_err(),
            "Unknown cipher byte: 255"
        );
        // 20 bytes: header is intact but the 16-byte salt is cut short.
        assert_eq!(
            KeyMaterial::from_bytes(&bytes[..20]).unwrap_err(),
            "KeyMaterial truncated at salt"
        );
        assert_eq!(
            KeyMaterial::from_bytes(&bytes[..bytes.len() - 1]).unwrap_err(),
            "KeyMaterial truncated at wrapped_key"
        );
    }

    #[test]
    fn test_hs_ext_type_roundtrip() {
        let types = [
            HsExtType::HsReq,
            HsExtType::HsRsp,
            HsExtType::KmReq,
            HsExtType::KmRsp,
            HsExtType::Sid,
            HsExtType::Group,
        ];
        for t in types {
            assert_eq!(HsExtType::from_u16(t.as_u16()), Some(t));
        }
    }

    #[test]
    fn test_hs_extension_key_material_roundtrip() {
        let km = KeyMaterial::new(KwAlgorithm::Aes192, 5);
        let ext = HsExtension::key_material(&km);
        assert_eq!(ext.ext_type, HsExtType::KmReq);
        let wire = ext.to_bytes();
        assert_eq!(wire.len(), 4 + usize::from(ext.ext_size) * 4);
        let parsed = HsExtension::from_bytes(&wire).expect("extension should parse");
        assert_eq!(parsed.ext_type, HsExtType::KmReq);
        assert_eq!(parsed.ext_size, ext.ext_size);
        // The parsed payload carries padding; KeyMaterial must tolerate it.
        let km2 = KeyMaterial::from_bytes(&parsed.data).expect("padded KM payload should parse");
        assert_eq!(km2.salt, km.salt);
        assert_eq!(km2.wrapped_key, km.wrapped_key);
    }

    #[test]
    fn test_hs_extension_stream_id_padding() {
        let ext = HsExtension::stream_id("abcde");
        assert_eq!(ext.ext_size, 2);
        assert_eq!(
            ext.to_bytes(),
            vec![0u8, 5, 0, 2, b'a', b'b', b'c', b'd', b'e', 0, 0, 0]
        );
    }

    #[test]
    fn test_hs_extension_from_bytes_errors() {
        assert_eq!(
            HsExtension::from_bytes(&[0, 5, 0]).unwrap_err(),
            "HsExtension too short"
        );
        assert_eq!(
            HsExtension::from_bytes(&[0, 9, 0, 0]).unwrap_err(),
            "Unknown HsExtType: 9"
        );
        assert_eq!(
            HsExtension::from_bytes(&[0, 5, 0, 2, 1, 2, 3]).unwrap_err(),
            "HsExtension data truncated"
        );
    }

    #[test]
    fn test_encryption_session_inactive() {
        let session = EncryptionSession::no_encryption();
        assert_eq!(session.state(), EncryptionState::NoEncryption);
        assert!(!session.is_active());
    }

    #[test]
    fn test_encryption_session_initiate() {
        let mut session = EncryptionSession::new(KwAlgorithm::Aes128);
        let km = session.initiate(1234);
        assert!(km.is_some(), "initiate should return Some(km)");
        assert_eq!(session.state(), EncryptionState::KeyMaterialSent);
        assert!(km.expect("should succeed in test").is_valid());
    }

    #[test]
    fn test_rotate_requires_active_session() {
        let mut session = EncryptionSession::new(KwAlgorithm::Aes128);
        assert!(session.rotate(1).is_none());
        assert_eq!(session.state(), EncryptionState::PendingKeyMaterial);
    }

    #[test]
    fn test_apply_invalid_peer_km_fails() {
        let mut session = EncryptionSession::new(KwAlgorithm::Aes128);
        let mut km = KeyMaterial::new(KwAlgorithm::Aes128, 3);
        km.sign = 0;
        assert!(session.apply_peer_km(km).is_err());
        assert_eq!(session.state(), EncryptionState::Failed);
        assert!(!session.is_active());
    }

    #[test]
    fn test_encryption_rotation() {
        let mut session = EncryptionSession::new(KwAlgorithm::Aes128);
        let km = KeyMaterial::new(KwAlgorithm::Aes128, 7);
        session
            .apply_peer_km(km)
            .expect("apply_peer_km should succeed");
        assert_eq!(session.state(), EncryptionState::Active);
        let interval = session.key_rotation_interval_packets;
        for _ in 0..interval {
            session.record_packet();
        }
        assert!(
            session.should_rotate(),
            "should_rotate must be true after interval packets"
        );
        let new_km = session.rotate(9999);
        assert!(new_km.is_some());
        assert_eq!(session.state(), EncryptionState::Rotating);
        assert!(!session.should_rotate());
    }
}