pub mod key;
pub mod stream;
mod wrap;
use std::fmt::Debug;
use bytes::BytesMut;
use crate::{packet::*, settings::*};
use stream::KeyMaterialError;
/// Reasons [`Decryption::decrypt`] can reject a received data packet.
///
/// The two "unexpected" variants hand the untouched packet back to the caller.
#[derive(PartialEq, Eq, Debug)]
pub enum DecryptionError {
    /// Keys are configured, but the packet is flagged `DataEncryption::None`.
    UnexpectedUnencryptedPacket(DataPacket),
    /// No keys are configured, but the packet is flagged `Even` or `Odd`.
    UnexpectedEncryptedPacket(DataPacket),
    /// Payload encryption failed.
    /// NOTE(review): not constructed anywhere in the code shown here.
    EncryptionFailure,
    /// The configured stream keys could not decrypt the payload.
    DecryptionFailure,
}
#[derive(Debug)]
pub struct Decryption(Option<(StreamEncryptionKeys, KeySettings)>);
impl Decryption {
pub fn new(settings: Option<CipherSettings>) -> Self {
Self(settings.map(|settings| (settings.stream_keys, settings.key_settings)))
}
pub fn decrypt(&self, packet: DataPacket) -> Result<(usize, DataPacket), DecryptionError> {
use DecryptionError::*;
let mut packet = packet;
match (packet.encryption, &self.0) {
(DataEncryption::None, None) => Ok((0, packet)),
(DataEncryption::None, Some(_)) => Err(UnexpectedUnencryptedPacket(packet)),
(DataEncryption::Even | DataEncryption::Odd, None) => {
Err(UnexpectedEncryptedPacket(packet))
}
(selected_sek, Some((stream_keys, _))) => {
let mut data = BytesMut::with_capacity(packet.payload.len());
data.extend_from_slice(&packet.payload[..]);
let bytes = stream_keys
.decrypt(selected_sek, packet.seq_number, &mut data)
.ok_or(DecryptionFailure)?;
packet.encryption = DataEncryption::None;
packet.payload = data.freeze();
Ok((bytes, packet))
}
}
}
pub fn refresh_key_material(
&mut self,
keying_material: KeyingMaterialMessage,
) -> Result<Option<KeyingMaterialMessage>, KeyMaterialError> {
let (stream_keys, key_settings) = self.0.as_mut().ok_or(KeyMaterialError::NoKeys)?;
*stream_keys = StreamEncryptionKeys::unwrap_from(key_settings, &keying_material)?;
Ok(Some(keying_material))
}
}
/// Sender-side payload encryption.
///
/// `None` means encryption is disabled, and `encrypt` passes packets through unchanged.
#[derive(Debug)]
pub struct Encryption(Option<EncryptionState>);
/// State of an enabled sender.
///
/// Holds the keys plus the packet countdowns that drive key-material pre-announcement,
/// retransmission and stream-key switching.
#[derive(Debug)]
struct EncryptionState {
    /// Passed to `commission_next_key` when a new key is commissioned.
    key_settings: KeySettings,
    /// Source of the refresh `period()` and `pre_announcement_period()`.
    key_refresh: KeyMaterialRefreshSettings,
    /// Keys used to encrypt outgoing payloads.
    stream_keys: StreamEncryptionKeys,
    /// Key slot (`Even`/`Odd`) used for outgoing packets.
    active_sek: DataEncryption,
    /// Packets left until the next key is commissioned and its announcement starts.
    packets_until_pre_announcement: usize,
    /// Packets left until pending key material is (re)transmitted.
    packets_until_transmit: usize,
    /// Packets left until the active key slot may be flipped.
    packets_until_key_switch: usize,
    /// Key material announced but not yet matched by a refresh response.
    /// `None` initially, and again once a matching response is handled.
    last_key_material: Option<KeyingMaterialMessage>,
}
impl EncryptionState {
fn try_encrypt_packet(&mut self, mut packet: DataPacket) -> Option<(usize, DataPacket)> {
let mut data = BytesMut::with_capacity(packet.payload.len());
data.extend_from_slice(&packet.payload[..]);
let bytes = self
.stream_keys
.encrypt(self.active_sek, packet.seq_number, &mut data)?;
packet.encryption = self.active_sek;
packet.payload = data.freeze();
Some((bytes, packet))
}
fn try_schedule_pre_announcment(&mut self) {
if self.packets_until_pre_announcement == 0 {
self.packets_until_pre_announcement = self.key_refresh.period();
self.packets_until_transmit = 0;
if self.last_key_material.is_none() {
self.last_key_material = self
.stream_keys
.commission_next_key(self.active_sek, &self.key_settings);
}
}
}
fn try_send_key_material(&mut self) -> Option<KeyingMaterialMessage> {
let km = self.last_key_material.as_ref()?;
if self.packets_until_transmit == 0 {
self.packets_until_transmit =
std::cmp::min(self.key_refresh.pre_announcement_period(), 1_000);
Some(km.clone())
} else {
self.packets_until_transmit -= 1;
None
}
}
fn try_switch_stream_keys(&mut self) {
use DataEncryption::*;
if self.packets_until_key_switch == 0 {
self.packets_until_key_switch = self.key_refresh.period();
if self.last_key_material.is_none() {
self.active_sek = match self.active_sek {
Even => Odd,
Odd => Even,
None => None,
};
}
}
}
}
impl Encryption {
pub fn new(settings: Option<CipherSettings>) -> Self {
Self(settings.map(|settings| EncryptionState {
key_settings: settings.key_settings,
key_refresh: settings.key_refresh.clone(),
stream_keys: settings.stream_keys,
active_sek: DataEncryption::Even,
packets_until_pre_announcement: settings.key_refresh.period()
- settings.key_refresh.pre_announcement_period(),
packets_until_transmit: 0,
packets_until_key_switch: settings.key_refresh.period(),
last_key_material: None,
}))
}
pub fn encrypt(
&mut self,
packet: DataPacket,
) -> Option<(usize, DataPacket, Option<KeyingMaterialMessage>)> {
match &mut self.0 {
Some(this) => {
let (bytes, packet) = this.try_encrypt_packet(packet)?;
this.try_schedule_pre_announcment();
this.try_switch_stream_keys();
let km = this.try_send_key_material();
this.packets_until_pre_announcement -= 1;
this.packets_until_key_switch -= 1;
Some((bytes, packet, km))
}
None => Some((0, packet, None)),
}
}
pub fn handle_key_refresh_response(
&mut self,
keying_material: KeyingMaterialMessage,
) -> Result<(), KeyMaterialError> {
use KeyMaterialError::*;
if let Some(settings) = self.0.as_mut() {
let expected_key_material = settings.last_key_material.as_ref().ok_or(NoKeys)?;
if keying_material == *expected_key_material {
settings.packets_until_transmit = 0;
settings.last_key_material = None;
} else {
return Err(InvalidRefreshResponse(keying_material));
}
}
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// AES-192 key settings with a fixed passphrase, shared by every test.
    fn key_settings() -> KeySettings {
        KeySettings {
            key_size: KeySize::AES192,
            passphrase: "1234567890".into(),
        }
    }

    /// Fresh random cipher settings from `key_settings()`, with `Default::default()`
    /// for the remaining argument.
    fn new_settings() -> CipherSettings {
        CipherSettings::new_random(&key_settings(), &Default::default())
    }

    /// Builds a data packet with fixed header fields (sequence number 3, message
    /// number 1), the given encryption flag, and `payload` as its bytes.
    fn data_packet(encryption: DataEncryption, payload: &str) -> DataPacket {
        DataPacket {
            seq_number: SeqNumber(3),
            message_loc: PacketLocation::ONLY,
            in_order_delivery: false,
            encryption,
            retransmitted: false,
            message_number: MsgNumber(1),
            timestamp: TimeStamp::MIN,
            dest_sockid: SocketId(0),
            payload: bytes::Bytes::copy_from_slice(payload.as_bytes()),
        }
    }

    /// One packet encrypted and then decrypted with the same settings comes back
    /// identical to the original.
    #[test]
    fn round_trip() {
        let settings = new_settings();
        let original_packet = data_packet(DataEncryption::None, "test round_trip");
        let mut encryption = Encryption::new(Some(settings.clone()));
        let (bytes, encrypted_packet, key_material) =
            encryption.encrypt(original_packet.clone()).unwrap();
        // The reported count is the payload length, the packet changed (payload and
        // flag), and no key material rides on the very first packet.
        assert_eq!(bytes, original_packet.payload.len());
        assert_ne!(encrypted_packet, original_packet);
        assert_eq!(key_material, None);
        let decryption = Decryption::new(Some(settings));
        let (bytes, decrypted_packet) = decryption.decrypt(encrypted_packet).unwrap();
        assert_eq!(bytes, original_packet.payload.len());
        assert_eq!(decrypted_packet, original_packet);
    }

    /// Mismatches between a packet's encryption flag and the configured keys are
    /// rejected and return the packet; unencrypted-on-unencrypted passes through.
    #[test]
    fn decryption_falure() {
        use DecryptionError::*;
        let with_keys = |with_keys| {
            if with_keys {
                Decryption::new(Some(new_settings()))
            } else {
                Decryption::new(None)
            }
        };
        let new_packet = |encryption| data_packet(encryption, "test decryption_falureR");
        // Keys configured, plaintext packet.
        let packet = new_packet(DataEncryption::None);
        assert_eq!(
            with_keys(true).decrypt(packet.clone()),
            Err(UnexpectedUnencryptedPacket(packet))
        );
        // No keys, packet flagged Even.
        let packet = new_packet(DataEncryption::Even);
        assert_eq!(
            with_keys(false).decrypt(packet.clone()),
            Err(UnexpectedEncryptedPacket(packet))
        );
        // No keys, packet flagged Odd.
        let packet = new_packet(DataEncryption::Odd);
        assert_eq!(
            with_keys(false).decrypt(packet.clone()),
            Err(UnexpectedEncryptedPacket(packet))
        );
        // No keys, plaintext packet: returned unchanged with a count of 0.
        let packet = new_packet(DataEncryption::None);
        assert_eq!(with_keys(false).decrypt(packet.clone()), Ok((0, packet)));
    }

    /// Walks one full refresh cycle with period 3_000 and pre-announcement 1_000.
    ///
    /// Covers the announcement, the acknowledgement, the Even -> Odd switch, and the
    /// start of the next announcement.
    #[test]
    fn refresh_key_material() {
        let settings = CipherSettings {
            key_refresh: KeyMaterialRefreshSettings::new(3_000, 1_000).unwrap(),
            ..new_settings()
        };
        let mut encryption = Encryption::new(Some(settings.clone()));
        let mut decryption = Decryption::new(Some(settings.clone()));
        let original_packet = data_packet(DataEncryption::None, "test refresh_key_material");
        // Packets #1..=#2000: no announcement yet, all on the Even key.
        let count = settings.key_refresh.period() - settings.key_refresh.pre_announcement_period();
        for i in 0..count {
            let (_, packet, km) = encryption.encrypt(original_packet.clone()).unwrap();
            assert_eq!(km, None);
            assert_eq!(packet.encryption, DataEncryption::Even, "{:?}", i);
        }
        // Packet #2001 is still encrypted with Even (encryption happens before scheduling),
        // but it carries the freshly commissioned key material.
        let (_, first_packet, km) = encryption.encrypt(original_packet.clone()).unwrap();
        assert_ne!(km, None);
        assert_eq!(first_packet.encryption, DataEncryption::Even);
        let key_material = km.unwrap();
        // Receiver installs the material and echoes it; sender accepts the echo, which
        // stops retransmission and allows the upcoming switch.
        let response = decryption.refresh_key_material(key_material.clone());
        assert_eq!(response, Ok(Some(key_material.clone())));
        assert_eq!(encryption.handle_key_refresh_response(key_material), Ok(()));
        // Packets #2002..=#3001 stay on Even; the switch fires after encrypting #3001.
        for i in 0..settings.key_refresh.pre_announcement_period() {
            let (_, packet, km) = encryption.encrypt(original_packet.clone()).unwrap();
            assert_eq!(km, None, "{:?}", i);
            assert_eq!(packet.encryption, DataEncryption::Even);
        }
        // Packet #3002 is the first one on the Odd key.
        let (_, second_packet, km) = encryption.encrypt(original_packet.clone()).unwrap();
        assert_eq!(km, None);
        assert_eq!(second_packet.encryption, DataEncryption::Odd);
        // The refreshed receiver keys decrypt both the old-slot and the new-slot packet.
        let (bytes, decrypted_packet) = decryption.decrypt(first_packet).unwrap();
        assert_eq!(bytes, original_packet.payload.len());
        assert_eq!(decrypted_packet, original_packet);
        let (bytes, decrypted_packet) = decryption.decrypt(second_packet).unwrap();
        assert_eq!(bytes, original_packet.payload.len());
        assert_eq!(decrypted_packet, original_packet);
        // Packets #3003..=#5000 (1_998 packets) stay on Odd with nothing announced.
        let count = settings.key_refresh.period() - settings.key_refresh.pre_announcement_period();
        for _ in 1..count - 1 {
            let (_, packet, km) = encryption.encrypt(original_packet.clone()).unwrap();
            assert_eq!(km, None);
            assert_eq!(packet.encryption, DataEncryption::Odd);
        }
        // Packet #5001 (one period after the first announcement) starts the next one.
        let (_, third_packet, km) = encryption.encrypt(original_packet.clone()).unwrap();
        assert_ne!(km, None);
        assert_eq!(third_packet.encryption, DataEncryption::Odd);
        let key_material = km.unwrap();
        let response = decryption.refresh_key_material(key_material.clone());
        assert_eq!(response, Ok(Some(key_material)));
        let (bytes, decrypted_packet) = decryption.decrypt(third_packet).unwrap();
        assert_eq!(bytes, original_packet.payload.len());
        assert_eq!(decrypted_packet, original_packet);
    }

    /// Unacknowledged key material is retransmitted periodically, and stops once a
    /// response arrives.
    ///
    /// Uses period 44_000 and pre-announcement 20_000.
    #[test]
    fn retry_refresh_key_material() {
        let settings = CipherSettings {
            key_refresh: KeyMaterialRefreshSettings::new(44_000, 20_000).unwrap(),
            ..new_settings()
        };
        let mut encryption = Encryption::new(Some(settings.clone()));
        let original_packet = data_packet(DataEncryption::None, "test refresh_key_material");
        let mut km_resp = None;
        // Over packets #1..=#34000 the material is first sent with #24001. It is then
        // re-sent every 1_001 packets: the transmit countdown restarts at
        // min(20_000, 1_000) and fires on the packet after it reaches zero. That gives
        // sends at #24001, #25002, ..., #33010: 10 in total.
        let count = (0..settings.key_refresh.period() - 10_000)
            .into_iter()
            .filter_map(|_| {
                let (_, packet, km) = encryption.encrypt(original_packet.clone()).unwrap();
                if let Some(km) = &km {
                    km_resp = Some(km.clone());
                }
                km.map(|k| (packet.encryption, k))
            })
            .count();
        assert_eq!(count, 10);
        encryption
            .handle_key_refresh_response(km_resp.unwrap())
            .unwrap();
        // Once acknowledged, nothing is sent for packets #34001..=#68000. The next
        // pre-announcement is due at #68001, one period after #24001.
        let count = (0..10_000
            + (settings.key_refresh.period() - settings.key_refresh.pre_announcement_period()))
            .into_iter()
            .filter_map(|_| {
                let (_, packet, km) = encryption.encrypt(original_packet.clone()).unwrap();
                km.map(|k| (packet.encryption, k))
            })
            .count();
        assert_eq!(count, 0);
    }
}