use crate::error::{constants, ProtocolError, Result};
use crate::protocol::message::Message;
use crate::utils::replay_cache::ReplayCache;
use rand_core::{OsRng, RngCore};
use sha2::{Digest, Sha256};
use std::time::{SystemTime, UNIX_EPOCH};
use x25519_dalek::{EphemeralSecret, PublicKey, SharedSecret};
use zeroize::Zeroize;
#[allow(unused_imports)]
use tracing::{debug, instrument, warn};
/// Client-side state carried between the steps of the secure handshake.
///
/// A value is created by `client_secure_handshake_init`, updated by
/// `client_secure_handshake_verify` and finally consumed by
/// `client_derive_session_key`. Every field is optional because the state is
/// filled in step by step.
///
/// The `#[zeroize(drop)]` attribute asks the `Zeroize` derive to wipe the
/// fields when the value is dropped.
#[derive(Zeroize)]
#[zeroize(drop)]
pub struct ClientHandshakeState {
// Client's ephemeral X25519 secret; taken out when the session key is derived.
secret: Option<EphemeralSecret>,
// Raw bytes of the client's public key that matches `secret`.
public: Option<[u8; 32]>,
// Raw bytes of the server's public key, stored once the response is verified.
server_public: Option<[u8; 32]>,
// Random nonce generated by the client for the init message.
client_nonce: Option<[u8; 16]>,
// Nonce received from the server, stored once the response is verified.
server_nonce: Option<[u8; 16]>,
}
impl ClientHandshakeState {
    /// Creates an empty state: no key material and no nonces are set yet.
    pub fn new() -> Self {
        let (secret, public, server_public) = (None, None, None);
        let (client_nonce, server_nonce) = (None, None);
        Self {
            secret,
            public,
            server_public,
            client_nonce,
            server_nonce,
        }
    }

    /// Test-only accessor for the client nonce, if one has been stored.
    #[cfg(test)]
    pub fn client_nonce(&self) -> Option<&[u8; 16]> {
        Option::as_ref(&self.client_nonce)
    }

    /// Test-only accessor for the server nonce, if one has been stored.
    #[cfg(test)]
    pub fn server_nonce(&self) -> Option<&[u8; 16]> {
        Option::as_ref(&self.server_nonce)
    }
}
impl Default for ClientHandshakeState {
fn default() -> Self {
Self::new()
}
}
/// Server-side state carried between the steps of the secure handshake.
///
/// A value is created by `server_secure_handshake_response` and consumed by
/// `server_secure_handshake_finalize`. Every field is optional because
/// `new` starts with nothing set.
///
/// The `#[zeroize(drop)]` attribute asks the `Zeroize` derive to wipe the
/// fields when the value is dropped.
#[derive(Zeroize)]
#[zeroize(drop)]
pub struct ServerHandshakeState {
// Server's ephemeral X25519 secret; taken out when the session key is derived.
secret: Option<EphemeralSecret>,
// Raw bytes of the server's public key that matches `secret`.
public: Option<[u8; 32]>,
// Raw bytes of the public key received from the client.
client_public: Option<[u8; 32]>,
// Nonce received from the client in the init message.
client_nonce: Option<[u8; 16]>,
// Random nonce generated by the server for the response message.
server_nonce: Option<[u8; 16]>,
}
impl ServerHandshakeState {
    /// Creates an empty state: no key material and no nonces are set yet.
    pub fn new() -> Self {
        let (secret, public, client_public) = (None, None, None);
        let (client_nonce, server_nonce) = (None, None);
        Self {
            secret,
            public,
            client_public,
            client_nonce,
            server_nonce,
        }
    }

    /// Test-only accessor for the server nonce, if one has been stored.
    #[cfg(test)]
    pub fn server_nonce(&self) -> Option<&[u8; 16]> {
        Option::as_ref(&self.server_nonce)
    }

    /// Test-only accessor for the client's public key bytes, if stored.
    #[cfg(test)]
    pub fn client_public(&self) -> Option<&[u8; 32]> {
        Option::as_ref(&self.client_public)
    }
}
impl Default for ServerHandshakeState {
fn default() -> Self {
Self::new()
}
}
/// Returns the current wall-clock time as milliseconds since the Unix epoch.
///
/// # Errors
///
/// Returns `ProtocolError::Custom` carrying `constants::ERR_SYSTEM_TIME` when
/// the system clock reports a time before the Unix epoch.
fn current_timestamp() -> Result<u64> {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => Ok(since_epoch.as_millis() as u64),
        Err(_) => Err(ProtocolError::Custom(constants::ERR_SYSTEM_TIME.into())),
    }
}
fn generate_nonce() -> [u8; 16] {
let mut nonce = [0u8; 16];
OsRng.fill_bytes(&mut nonce);
nonce
}
/// Checks whether `timestamp` falls inside the accepted freshness window.
///
/// `timestamp` is expressed in milliseconds since the Unix epoch, the same
/// unit `current_timestamp` produces. It is accepted when it is at most
/// `max_age_seconds` old and at most two seconds ahead of the local clock.
///
/// Returns `false` when the timestamp is outside that window or when the
/// local clock cannot be read.
///
/// All arithmetic saturates, so very large `max_age_seconds` values mean
/// "any age is acceptable" instead of overflowing.
pub fn verify_timestamp(timestamp: u64, max_age_seconds: u64) -> bool {
    /// How far ahead of the local clock a timestamp may be (clock-skew allowance).
    const FUTURE_TOLERANCE_MS: u64 = 2000;
    const MILLIS_PER_SECOND: u64 = 1000;

    let current = match current_timestamp() {
        Ok(time) => time,
        Err(_) => return false,
    };

    // Reject timestamps that lie too far in the future.
    if timestamp > current.saturating_add(FUTURE_TOLERANCE_MS) {
        return false;
    }

    // A timestamp at or after `current` has age zero thanks to the saturating
    // subtraction, so only timestamps from the past can exceed the limit.
    let max_age_ms = max_age_seconds.saturating_mul(MILLIS_PER_SECOND);
    current.saturating_sub(timestamp) <= max_age_ms
}
/// Returns the SHA-256 digest of `nonce`.
///
/// The handshake uses this digest as the "nonce verification" value that each
/// side echoes back to the other.
fn hash_nonce(nonce: &[u8]) -> [u8; 32] {
    Sha256::digest(nonce).into()
}
/// Derives the 32-byte session key from the Diffie-Hellman output and both
/// handshake nonces.
///
/// The key is the SHA-256 digest of, in order: the shared secret bytes, the
/// label `client_nonce`, the client nonce, the label `server_nonce`, and the
/// server nonce.
fn derive_key_from_shared_secret(
    shared_secret: &SharedSecret,
    client_nonce: &[u8],
    server_nonce: &[u8],
) -> [u8; 32] {
    // The order of the segments is part of the key derivation; do not reorder.
    let segments: [&[u8]; 5] = [
        shared_secret.as_bytes(),
        b"client_nonce",
        client_nonce,
        b"server_nonce",
        server_nonce,
    ];

    let mut hasher = Sha256::new();
    for segment in segments.iter() {
        hasher.update(*segment);
    }
    hasher.finalize().into()
}
#[instrument]
pub fn client_secure_handshake_init() -> Result<(ClientHandshakeState, Message)> {
let client_secret = EphemeralSecret::random_from_rng(OsRng);
let client_public = PublicKey::from(&client_secret);
let nonce = generate_nonce();
let timestamp = current_timestamp()?;
let mut state = ClientHandshakeState::new();
state.secret = Some(client_secret);
state.public = Some(client_public.to_bytes());
state.client_nonce = Some(nonce);
debug!("Client initiating secure handshake");
Ok((
state,
Message::SecureHandshakeInit {
pub_key: client_public.to_bytes(),
timestamp,
nonce,
},
))
}
#[instrument(skip(client_pub_key, client_nonce, replay_cache))]
pub fn server_secure_handshake_response(
client_pub_key: [u8; 32],
client_nonce: [u8; 16],
client_timestamp: u64,
peer_id: &str,
replay_cache: &mut ReplayCache,
) -> Result<(ServerHandshakeState, Message)> {
if !verify_timestamp(client_timestamp, 30) {
return Err(ProtocolError::HandshakeError(
constants::ERR_INVALID_TIMESTAMP.into(),
));
}
if replay_cache.is_replay(peer_id, &client_nonce, client_timestamp) {
return Err(ProtocolError::HandshakeError(
constants::ERR_REPLAY_ATTACK.into(),
));
}
let server_secret = EphemeralSecret::random_from_rng(OsRng);
let server_public = PublicKey::from(&server_secret);
let server_nonce = generate_nonce();
let nonce_verification = hash_nonce(&client_nonce);
let mut state = ServerHandshakeState::new();
state.secret = Some(server_secret);
state.public = Some(server_public.to_bytes());
state.client_public = Some(client_pub_key);
state.client_nonce = Some(client_nonce);
state.server_nonce = Some(server_nonce);
debug!("Server responding to handshake initiation");
Ok((
state,
Message::SecureHandshakeResponse {
pub_key: server_public.to_bytes(),
nonce: server_nonce,
nonce_verification,
},
))
}
/// Verifies the server's handshake response on the client side.
///
/// Checks that `nonce_verification` equals the hash of the nonce this client
/// sent in its init message, then records the server's public key and nonce
/// in `state`. Returns the updated state together with the
/// `Message::SecureHandshakeConfirm` to send back, which carries the hash of
/// the server's nonce.
///
/// # Errors
///
/// Returns `ProtocolError::HandshakeError` with:
/// - `constants::ERR_REPLAY_ATTACK` when the replay cache flags the server
///   nonce,
/// - `constants::ERR_CLIENT_NONCE_NOT_FOUND` when `state` holds no client
///   nonce,
/// - `constants::ERR_NONCE_VERIFICATION_FAILED` when the hash does not match.
#[instrument(skip(state, server_pub_key, server_nonce, nonce_verification, replay_cache))]
pub fn client_secure_handshake_verify(
    mut state: ClientHandshakeState,
    server_pub_key: [u8; 32],
    server_nonce: [u8; 16],
    nonce_verification: [u8; 32],
    peer_id: &str,
    replay_cache: &mut ReplayCache,
) -> Result<(ClientHandshakeState, Message)> {
    // The response message carries no timestamp, so 0 is passed to the cache.
    // NOTE(review): confirm `ReplayCache::is_replay` treats 0 as intended.
    if replay_cache.is_replay(peer_id, &server_nonce, 0) {
        return Err(ProtocolError::HandshakeError(
            constants::ERR_REPLAY_ATTACK.into(),
        ));
    }

    let client_nonce = state.client_nonce.ok_or_else(|| {
        ProtocolError::HandshakeError(constants::ERR_CLIENT_NONCE_NOT_FOUND.into())
    })?;

    // The server proves it saw our init message by echoing the hash of our nonce.
    if hash_nonce(&client_nonce) != nonce_verification {
        return Err(ProtocolError::HandshakeError(
            constants::ERR_NONCE_VERIFICATION_FAILED.into(),
        ));
    }

    // Store the server's values only after the response has been verified.
    state.server_public = Some(server_pub_key);
    state.server_nonce = Some(server_nonce);

    let hash = hash_nonce(&server_nonce);
    debug!("Client verified server response");

    Ok((
        state,
        Message::SecureHandshakeConfirm {
            nonce_verification: hash,
        },
    ))
}
/// Completes the handshake on the server side and derives the session key.
///
/// Checks that `nonce_verification` equals the hash of the server's own
/// nonce, then performs the Diffie-Hellman exchange with the client's public
/// key and derives the 32-byte session key from the shared secret and both
/// nonces. The state is consumed.
///
/// # Errors
///
/// Returns `ProtocolError::HandshakeError` with, in order of checking:
/// `constants::ERR_SERVER_NONCE_NOT_FOUND`,
/// `constants::ERR_SERVER_VERIFICATION_FAILED`,
/// `constants::ERR_SERVER_SECRET_NOT_FOUND`,
/// `constants::ERR_CLIENT_PUBLIC_NOT_FOUND` or
/// `constants::ERR_CLIENT_NONCE_NOT_FOUND`.
#[instrument(skip(state, nonce_verification))]
pub fn server_secure_handshake_finalize(
    mut state: ServerHandshakeState,
    nonce_verification: [u8; 32],
) -> Result<[u8; 32]> {
    let own_nonce = match state.server_nonce {
        Some(nonce) => nonce,
        None => {
            return Err(ProtocolError::HandshakeError(
                constants::ERR_SERVER_NONCE_NOT_FOUND.into(),
            ))
        }
    };

    // The client proves it saw our response by echoing the hash of our nonce.
    if hash_nonce(&own_nonce) != nonce_verification {
        return Err(ProtocolError::HandshakeError(
            constants::ERR_SERVER_VERIFICATION_FAILED.into(),
        ));
    }

    // The ephemeral secret is moved out of the state; it is used exactly once.
    let own_secret = match state.secret.take() {
        Some(secret) => secret,
        None => {
            return Err(ProtocolError::HandshakeError(
                constants::ERR_SERVER_SECRET_NOT_FOUND.into(),
            ))
        }
    };
    let peer_public_bytes = match state.client_public {
        Some(bytes) => bytes,
        None => {
            return Err(ProtocolError::HandshakeError(
                constants::ERR_CLIENT_PUBLIC_NOT_FOUND.into(),
            ))
        }
    };
    let peer_nonce = match state.client_nonce {
        Some(nonce) => nonce,
        None => {
            return Err(ProtocolError::HandshakeError(
                constants::ERR_CLIENT_NONCE_NOT_FOUND.into(),
            ))
        }
    };

    let shared = own_secret.diffie_hellman(&PublicKey::from(peer_public_bytes));
    let session_key = derive_key_from_shared_secret(&shared, &peer_nonce, &own_nonce);
    debug!("Server finalized handshake and derived session key");
    Ok(session_key)
}
/// Derives the session key on the client side from a verified handshake state.
///
/// Performs the Diffie-Hellman exchange between the client's ephemeral secret
/// and the server's public key, then derives the 32-byte session key from the
/// shared secret and both nonces. The state is consumed.
///
/// # Errors
///
/// Returns `ProtocolError::HandshakeError` with, in order of checking:
/// `constants::ERR_CLIENT_SECRET_NOT_FOUND`,
/// `constants::ERR_SERVER_PUBLIC_NOT_FOUND`,
/// `constants::ERR_CLIENT_NONCE_NOT_FOUND` or
/// `constants::ERR_SERVER_NONCE_NOT_FOUND`.
#[instrument(skip(state))]
pub fn client_derive_session_key(mut state: ClientHandshakeState) -> Result<[u8; 32]> {
    // The ephemeral secret is moved out of the state; it is used exactly once.
    let own_secret = match state.secret.take() {
        Some(secret) => secret,
        None => {
            return Err(ProtocolError::HandshakeError(
                constants::ERR_CLIENT_SECRET_NOT_FOUND.into(),
            ))
        }
    };
    let peer_public_bytes = match state.server_public {
        Some(bytes) => bytes,
        None => {
            return Err(ProtocolError::HandshakeError(
                constants::ERR_SERVER_PUBLIC_NOT_FOUND.into(),
            ))
        }
    };
    let own_nonce = match state.client_nonce {
        Some(nonce) => nonce,
        None => {
            return Err(ProtocolError::HandshakeError(
                constants::ERR_CLIENT_NONCE_NOT_FOUND.into(),
            ))
        }
    };
    let peer_nonce = match state.server_nonce {
        Some(nonce) => nonce,
        None => {
            return Err(ProtocolError::HandshakeError(
                constants::ERR_SERVER_NONCE_NOT_FOUND.into(),
            ))
        }
    };

    let shared = own_secret.diffie_hellman(&PublicKey::from(peer_public_bytes));
    // Argument order is client nonce, then server nonce, matching the server side.
    let session_key = derive_key_from_shared_secret(&shared, &own_nonce, &peer_nonce);
    debug!("Client derived session key");
    Ok(session_key)
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;

    /// Splits a `SecureHandshakeInit` into `(pub_key, timestamp, nonce)`.
    fn unpack_init(msg: Message) -> ([u8; 32], u64, [u8; 16]) {
        match msg {
            Message::SecureHandshakeInit {
                pub_key,
                timestamp,
                nonce,
            } => (pub_key, timestamp, nonce),
            _ => panic!("Wrong message type"),
        }
    }

    /// Splits a `SecureHandshakeResponse` into `(pub_key, nonce, nonce_verification)`.
    fn unpack_response(msg: Message) -> ([u8; 32], [u8; 16], [u8; 32]) {
        match msg {
            Message::SecureHandshakeResponse {
                pub_key,
                nonce,
                nonce_verification,
            } => (pub_key, nonce, nonce_verification),
            _ => panic!("Wrong message type"),
        }
    }

    /// Extracts the verification hash from a `SecureHandshakeConfirm`.
    fn unpack_confirm(msg: Message) -> [u8; 32] {
        match msg {
            Message::SecureHandshakeConfirm { nonce_verification } => nonce_verification,
            _ => panic!("Wrong message type"),
        }
    }

    /// Two handshakes run side by side must not share keys or nonces, and each
    /// must end with client and server agreeing on its own session key.
    #[test]
    fn test_per_session_state_isolation() {
        let mut replay_cache = crate::utils::replay_cache::ReplayCache::new();
        let peer_id = "test-peer";

        // Step 1: both clients initiate.
        let (client1, init1) = client_secure_handshake_init().unwrap();
        let (client2, init2) = client_secure_handshake_init().unwrap();
        let (pub_key1, ts1, nonce1) = unpack_init(init1);
        let (pub_key2, ts2, nonce2) = unpack_init(init2);
        assert_ne!(pub_key1, pub_key2);
        assert_ne!(nonce1, nonce2);

        // Step 2: the server answers each initiation.
        let (server1, resp1) =
            server_secure_handshake_response(pub_key1, nonce1, ts1, peer_id, &mut replay_cache)
                .unwrap();
        let (server2, resp2) =
            server_secure_handshake_response(pub_key2, nonce2, ts2, peer_id, &mut replay_cache)
                .unwrap();
        let (server_pub1, server_nonce1, verify1) = unpack_response(resp1);
        let (server_pub2, server_nonce2, verify2) = unpack_response(resp2);
        assert_ne!(server_pub1, server_pub2);
        assert_ne!(server_nonce1, server_nonce2);

        // Step 3: each client verifies its own response.
        let (client1_verified, confirm1) = client_secure_handshake_verify(
            client1,
            server_pub1,
            server_nonce1,
            verify1,
            peer_id,
            &mut replay_cache,
        )
        .unwrap();
        let (client2_verified, confirm2) = client_secure_handshake_verify(
            client2,
            server_pub2,
            server_nonce2,
            verify2,
            peer_id,
            &mut replay_cache,
        )
        .unwrap();
        let confirm_hash1 = unpack_confirm(confirm1);
        let confirm_hash2 = unpack_confirm(confirm2);
        assert_ne!(confirm_hash1, confirm_hash2);

        // Step 4: both sides derive their session keys.
        let key1_server = server_secure_handshake_finalize(server1, confirm_hash1).unwrap();
        let key1_client = client_derive_session_key(client1_verified).unwrap();
        let key2_server = server_secure_handshake_finalize(server2, confirm_hash2).unwrap();
        let key2_client = client_derive_session_key(client2_verified).unwrap();

        assert_eq!(key1_server, key1_client);
        assert_eq!(key2_server, key2_client);
        assert_ne!(key1_server, key2_server);
    }

    /// Exercises both edges of the freshness window with a 30-second maximum age.
    #[test]
    fn test_timestamp_validation() {
        let now = current_timestamp().unwrap();

        assert!(verify_timestamp(now, 30));
        // 10 seconds old: inside the window.
        assert!(verify_timestamp(now - 10000, 30));
        // 31 seconds old: too old.
        assert!(!verify_timestamp(now - 31000, 30));
        // 1 second ahead: within the future tolerance.
        assert!(verify_timestamp(now + 1000, 30));
        // 3 seconds ahead: beyond the future tolerance.
        assert!(!verify_timestamp(now + 3000, 30));
    }

    /// The nonce hash is 32 bytes, deterministic, and sensitive to its input.
    #[test]
    fn test_nonce_verification() {
        let nonce = generate_nonce();
        let hash = hash_nonce(&nonce);
        assert_eq!(hash.len(), 32);
        assert_eq!(hash, hash_nonce(&nonce));

        let mut different_nonce = nonce;
        different_nonce[0] ^= 0xFF;
        assert_ne!(hash, hash_nonce(&different_nonce));
    }
}