use crate::error::ProtocolError;
/// Largest handshake message, in bytes, that this module will process (512 bytes).
const MAX_HANDSHAKE_MSG: usize = 1 << 9;
/// Which side of the handshake a `HandshakeState` plays.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum HandshakeRole {
    /// Side constructed with the peer's public key supplied up front.
    Initiator,
    /// Side constructed without any knowledge of the peer's public key.
    Responder,
}
/// Per-side progress through the handshake, plus the key material it carries.
pub struct HandshakeState {
    /// Which side of the handshake this state belongs to.
    role: HandshakeRole,
    /// Local secret key.
    local_sk: [u8; 32],
    /// Peer's public key, when it is known.
    remote_pk: Option<[u8; 32]>,
    /// Number of handshake messages processed so far (sent and received combined).
    msg_count: usize,
    /// Becomes `true` once enough messages have been exchanged.
    complete: bool,
    /// NOTE(review): presumably a cookie to echo back to the peer; not read by the
    /// handshake logic in this file — confirm its purpose with callers.
    pub cookie_echo: Option<[u8; 32]>,
}
impl HandshakeState {
    /// Creates the handshake state for the side that sends the first message.
    ///
    /// `local_sk` is this side's secret key; `remote_pk` is the peer's public key,
    /// which the initiator must supply up front.
    ///
    /// # Errors
    ///
    /// Currently never fails; the `Result` is kept as part of the public signature.
    pub fn new_initiator(local_sk: [u8; 32], remote_pk: [u8; 32]) -> Result<Self, ProtocolError> {
        Ok(Self {
            local_sk,
            remote_pk: Some(remote_pk),
            role: HandshakeRole::Initiator,
            complete: false,
            msg_count: 0,
            cookie_echo: None,
        })
    }

    /// Creates the handshake state for the side that receives the first message.
    ///
    /// The peer's public key is not known at construction time.
    /// NOTE(review): nothing in this impl ever fills in `remote_pk` for a responder.
    ///
    /// # Errors
    ///
    /// Currently never fails; the `Result` is kept as part of the public signature.
    pub fn new_responder(local_sk: [u8; 32]) -> Result<Self, ProtocolError> {
        Ok(Self {
            local_sk,
            remote_pk: None,
            role: HandshakeRole::Responder,
            complete: false,
            msg_count: 0,
            cookie_echo: None,
        })
    }

    /// Produces this side's next outgoing handshake message.
    ///
    /// NOTE(review): placeholder implementation — `payload` is not encoded and the
    /// returned message is always 32 zero bytes.
    ///
    /// # Errors
    ///
    /// Returns `ProtocolError::ProtocolViolation` if `payload` is longer than
    /// `MAX_HANDSHAKE_MSG`, if the handshake is already complete, or if it is the
    /// peer's turn to send. On error the state is left unchanged.
    pub fn write_message(&mut self, payload: &[u8]) -> Result<Vec<u8>, ProtocolError> {
        // The peer's `read_message` rejects anything over the limit, so an oversized
        // payload could never be delivered.
        if payload.len() > MAX_HANDSHAKE_MSG {
            return Err(ProtocolError::ProtocolViolation);
        }
        self.advance_step(true)?;
        Ok(vec![0u8; 32])
    }

    /// Consumes the peer's next handshake message.
    ///
    /// NOTE(review): placeholder implementation — the contents of `msg` are not
    /// parsed and the returned payload is always 32 zero bytes.
    ///
    /// # Errors
    ///
    /// Returns `ProtocolError::ProtocolViolation` if `msg` is longer than
    /// `MAX_HANDSHAKE_MSG`, if the handshake is already complete, or if it is this
    /// side's turn to send. On error the state is left unchanged.
    pub fn read_message(&mut self, msg: &[u8]) -> Result<Vec<u8>, ProtocolError> {
        if msg.len() > MAX_HANDSHAKE_MSG {
            return Err(ProtocolError::ProtocolViolation);
        }
        self.advance_step(false)?;
        Ok(vec![0u8; 32])
    }

    /// Returns `true` once the handshake has exchanged all of its messages.
    pub fn is_complete(&self) -> bool {
        self.complete
    }

    /// Consumes the finished handshake and returns `(tx_key, rx_key)`.
    ///
    /// SECURITY(review): placeholder derivation. `tx_key` is the local secret key
    /// itself, and `rx_key` is the peer's public key (or, when no peer key is known,
    /// the local secret key again). These are not session keys, and one side's
    /// `tx_key` does not match the other side's `rx_key`. Replace this with a real
    /// key schedule (DH + KDF) before relying on it.
    ///
    /// # Errors
    ///
    /// Returns `ProtocolError::ProtocolViolation` if the handshake is not complete.
    pub fn into_transport_keys(self) -> Result<([u8; 32], [u8; 32]), ProtocolError> {
        if !self.complete {
            return Err(ProtocolError::ProtocolViolation);
        }
        let tx_key = self.local_sk;
        let rx_key = self.remote_pk.unwrap_or(self.local_sk);
        Ok((tx_key, rx_key))
    }

    /// Returns which side of the handshake this state plays.
    pub fn role(&self) -> HandshakeRole {
        self.role
    }

    /// Returns `true` when this side's next handshake step is a send.
    ///
    /// Messages alternate, starting with the initiator: the initiator sends on even
    /// steps and receives on odd ones, and the responder does the opposite.
    fn expects_write(&self) -> bool {
        let even_step = self.msg_count % 2 == 0;
        match self.role {
            HandshakeRole::Initiator => even_step,
            HandshakeRole::Responder => !even_step,
        }
    }

    /// Checks that a send (`writing == true`) or receive is allowed now, then
    /// records it.
    ///
    /// Fails without touching state if the handshake is finished or the direction
    /// is out of turn.
    fn advance_step(&mut self, writing: bool) -> Result<(), ProtocolError> {
        if self.complete || writing != self.expects_write() {
            return Err(ProtocolError::ProtocolViolation);
        }
        self.msg_count += 1;
        // Two messages (one sent and one received by each side) finish the handshake.
        self.complete = self.msg_count >= 2;
        Ok(())
    }
}