use hkdf::Hkdf;
use sha2::{Digest, Sha256};
use uuid::Uuid;
use zeroize::Zeroize;
use crate::error::{EnigmaProtocolError, Result};
use crate::types::{InitiatorOrResponder, SessionBootstrap};
// HKDF `info` labels for domain separation. Every derived key depends on these exact bytes,
// so changing any of them breaks compatibility with peers that use the old values.
/// HKDF `info` label used when expanding the bootstrap secret into the two initial chain keys.
const INFO_SEED: &[u8] = b"enigma-protocol-seed";
/// HKDF `info` label used to derive the successor chain key at each ratchet step.
const INFO_NEXT: &[u8] = b"enigma-protocol-next";
/// HKDF `info` label used to derive the per-message key at each ratchet step.
const INFO_MSG: &[u8] = b"enigma-protocol-msg";
/// Builds the associated-data byte string for a message in a conversation.
///
/// Layout: the 16 raw bytes of `conversation_id`, then the UTF-8 bytes of `sender`, then
/// the UTF-8 bytes of `receiver`. There are no separators or length prefixes.
/// (Presumably callers pass this as AEAD associated data; confirm.)
///
/// NOTE(review): because the identities are concatenated without delimiters, distinct
/// pairs such as ("ab", "c") and ("a", "bc") yield identical bytes. The layout is kept
/// as-is because peers must reproduce it exactly. Consider length-prefixing in a
/// versioned protocol revision.
pub fn conversation_ad_bytes(conversation_id: &Uuid, sender: &str, receiver: &str) -> Vec<u8> {
    let pieces: [&[u8]; 3] = [
        &conversation_id.as_bytes()[..],
        sender.as_bytes(),
        receiver.as_bytes(),
    ];
    pieces.concat()
}
/// Derives a conversation identifier from the shared secret and both peer identities.
///
/// The identities are hashed in sorted order, so both sides compute the same id regardless
/// of which one is "local". The id is the first 16 bytes of
/// `SHA-256(secret || min(local, remote) || max(local, remote))`, used as raw UUID bytes.
///
/// NOTE(review): the identities are concatenated without length prefixes, so pairs such
/// as ("ab", "c") and ("a", "bc") map to the same id. This is left unchanged for
/// compatibility with existing peers.
pub fn derive_conversation_id(secret: &[u8; 32], local: &str, remote: &str) -> Uuid {
    let (lower, upper) = (local.min(remote), local.max(remote));
    let parts: [&[u8]; 3] = [secret, lower.as_bytes(), upper.as_bytes()];
    let mut sha = Sha256::new();
    for &part in &parts {
        sha.update(part);
    }
    let digest = sha.finalize();
    let mut id = [0u8; 16];
    id.copy_from_slice(&digest[..16]);
    Uuid::from_bytes(id)
}
/// Per-conversation symmetric hash ratchet with independent send and receive chains.
///
/// Each direction holds a 32-byte chain key and a counter of how many keys have been
/// derived from it. Both chain keys are wiped with `zeroize` when the ratchet is dropped.
pub struct SessionRatchet {
/// Conversation identifier computed at construction via `derive_conversation_id`.
conversation_id: Uuid,
/// Current chain key for outgoing messages; replaced on every successful send-key derivation.
send_chain: [u8; 32],
/// Current chain key for incoming messages; replaced on every successful receive-key derivation.
recv_chain: [u8; 32],
/// Index of the next outgoing message key; also used as HKDF input key material for it.
send_counter: u64,
/// Index of the next incoming message key; also used as HKDF input key material for it.
recv_counter: u64,
}
impl SessionRatchet {
    /// Builds a ratchet from the session bootstrap material.
    ///
    /// For [`SessionBootstrap::PreSharedSecret`], the 32-byte secret is expanded with a salt
    /// derived from the secret, both identities (in sorted order) and, if present,
    /// `remote_dh_pub`. The two resulting chain keys are assigned to send/receive according
    /// to `role`, with the assignment mirrored between initiator and responder. Both counters
    /// start at zero.
    ///
    /// NOTE(review): the salt hashes `remote_dh_pub` exactly as supplied, so both peers must
    /// pass the same value (or both `None`) for their chains to line up. Confirm against
    /// callers.
    ///
    /// # Errors
    ///
    /// Propagates the error from initial chain derivation (HKDF expansion failure).
    pub fn from_bootstrap(bootstrap: &SessionBootstrap, local: &str, remote: &str) -> Result<Self> {
        match bootstrap {
            SessionBootstrap::PreSharedSecret {
                secret32,
                role,
                remote_dh_pub,
            } => {
                let salt = derive_salt(secret32, local, remote, remote_dh_pub.as_ref());
                let (send_chain, recv_chain) = derive_initial_chains(secret32, &salt, *role)?;
                let conversation_id = derive_conversation_id(secret32, local, remote);
                Ok(Self {
                    conversation_id,
                    send_chain,
                    recv_chain,
                    send_counter: 0,
                    recv_counter: 0,
                })
            }
        }
    }

    /// Returns the conversation identifier computed when the ratchet was built.
    pub fn conversation_id(&self) -> Uuid {
        self.conversation_id
    }

    /// Derives the next outgoing message key and advances the send chain.
    ///
    /// # Errors
    ///
    /// Returns `EnigmaProtocolError::InvalidState` once the send counter cannot be
    /// incremented, and `EnigmaProtocolError::Crypto` if HKDF expansion fails. In both cases
    /// the send chain and counter are left unchanged.
    pub fn next_send_key(&mut self) -> Result<[u8; 32]> {
        Self::advance(&mut self.send_chain, &mut self.send_counter)
    }

    /// Derives the next incoming message key and advances the receive chain.
    ///
    /// # Errors
    ///
    /// Returns `EnigmaProtocolError::InvalidState` once the receive counter cannot be
    /// incremented, and `EnigmaProtocolError::Crypto` if HKDF expansion fails. In both cases
    /// the receive chain and counter are left unchanged.
    pub fn next_recv_key(&mut self) -> Result<[u8; 32]> {
        Self::advance(&mut self.recv_chain, &mut self.recv_counter)
    }

    /// Derives the key at index `*counter` from `chain`, then advances both.
    ///
    /// The counter increment is checked *before* the chain is touched. This way an
    /// exhausted counter cannot leave the chain advanced while the counter stays behind.
    fn advance(chain: &mut [u8; 32], counter: &mut u64) -> Result<[u8; 32]> {
        let following = counter
            .checked_add(1)
            .ok_or(EnigmaProtocolError::InvalidState)?;
        let key = next_key(chain, *counter)?;
        *counter = following;
        Ok(key)
    }
}
impl Drop for SessionRatchet {
    /// Overwrites both chain keys with zeros before the ratchet's memory is released.
    fn drop(&mut self) {
        for chain in [&mut self.send_chain, &mut self.recv_chain] {
            chain.zeroize();
        }
    }
}
/// Computes the HKDF salt used to derive the initial chain keys.
///
/// The salt is `SHA-256(secret || min(local, remote) || max(local, remote) [|| remote_dh])`.
/// The identity ordering makes it independent of which side is "local".
///
/// NOTE(review): `remote_dh` is hashed exactly as supplied. If each peer passes the
/// *other* side's public key, the two salts will differ and the chains will not match.
/// Confirm that both peers pass the same value (or `None`). The fields are also
/// concatenated without length prefixes, so identity pairs like ("ab", "c") and
/// ("a", "bc") hash identically.
fn derive_salt(
    secret: &[u8; 32],
    local: &str,
    remote: &str,
    remote_dh: Option<&[u8; 32]>,
) -> [u8; 32] {
    let (lower, upper) = (local.min(remote), local.max(remote));
    let fixed: [&[u8]; 3] = [secret, lower.as_bytes(), upper.as_bytes()];
    let mut sha = Sha256::new();
    // The optional DH key, when present, is appended after the fixed fields.
    for part in fixed.iter().copied().chain(remote_dh.map(|dh| &dh[..])) {
        sha.update(part);
    }
    let mut out = [0u8; 32];
    out.copy_from_slice(&sha.finalize());
    out
}
fn derive_initial_chains(
seed: &[u8; 32],
salt: &[u8; 32],
role: InitiatorOrResponder,
) -> Result<([u8; 32], [u8; 32])> {
let hk = Hkdf::<Sha256>::new(Some(salt), seed);
let mut okm = [0u8; 64];
hk.expand(INFO_SEED, &mut okm)
.map_err(|_| EnigmaProtocolError::Crypto)?;
let mut first = [0u8; 32];
first.copy_from_slice(&okm[..32]);
let mut second = [0u8; 32];
second.copy_from_slice(&okm[32..]);
match role {
InitiatorOrResponder::Initiator => Ok((first, second)),
InitiatorOrResponder::Responder => Ok((second, first)),
}
}
fn next_key(chain: &mut [u8; 32], counter: u64) -> Result<[u8; 32]> {
let mut counter_bytes = [0u8; 8];
counter_bytes.copy_from_slice(&counter.to_be_bytes());
let hk = Hkdf::<Sha256>::new(Some(chain), &counter_bytes);
let mut message_key = [0u8; 32];
hk.expand(INFO_MSG, &mut message_key)
.map_err(|_| EnigmaProtocolError::Crypto)?;
let mut next_chain = [0u8; 32];
hk.expand(INFO_NEXT, &mut next_chain)
.map_err(|_| EnigmaProtocolError::Crypto)?;
chain.copy_from_slice(&next_chain);
Ok(message_key)
}