use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use thiserror::Error;
use tracing::{debug, error, info};
use uuid::Uuid;
use qudag_crypto::{
Ciphertext as KemCiphertext, KeyPair as KemKeyPair, MlDsaKeyPair, MlDsaPublicKey, MlKem768,
PublicKey as KemPublicKey, SecretKey, SharedSecret,
};
use rand;
use crate::message::{HandshakeType, Message, MessageError, MessageType, ProtocolVersion};
use crate::state::{ProtocolStateMachine, StateError};
/// Errors reported by the handshake coordinator.
///
/// `StateError` and `MessageError` convert into this type automatically via
/// `#[from]`, so `?` can be used on state-machine and message operations.
#[derive(Debug, Error)]
pub enum HandshakeError {
/// A cryptographic step failed (key generation, key/ciphertext parsing,
/// KEM encapsulation or decapsulation); `reason` carries the detail.
#[error("Cryptographic operation failed: {reason}")]
CryptoError { reason: String },
/// The message could not be (de)serialized, carried an unexpected payload,
/// referenced an unknown session, or arrived in the wrong session state.
#[error("Invalid handshake message: {reason}")]
InvalidMessage { reason: String },
/// The handshake exceeded its time budget.
/// NOTE(review): not constructed by the code visible in this file.
#[error("Handshake timed out after {timeout:?}")]
Timeout { timeout: Duration },
/// No protocol version acceptable to both sides was found.
#[error("Protocol version mismatch: expected {expected:?}, got {actual:?}")]
VersionMismatch {
expected: ProtocolVersion,
actual: ProtocolVersion,
},
/// The peer did not advertise a locally required capability; the list
/// holds the first missing capability that was detected.
#[error("Unsupported capabilities: {capabilities:?}")]
UnsupportedCapabilities { capabilities: Vec<String> },
/// Signature verification of a handshake message returned `false`.
#[error("Invalid peer credentials")]
InvalidCredentials,
/// Error bubbled up from the protocol state machine.
#[error("State machine error: {0}")]
StateMachine(#[from] StateError),
/// Error bubbled up from message signing or verification.
#[error("Message error: {0}")]
Message(#[from] MessageError),
/// A handshake with the given session is already running.
/// NOTE(review): not constructed by the code visible in this file.
#[error("Handshake already in progress with session {session_id}")]
HandshakeInProgress { session_id: Uuid },
/// The message timestamp differs from the local clock by more than the
/// configured skew (raised for timestamps too old *or* too far ahead).
#[error("Replay attack detected: timestamp {timestamp} is too old")]
ReplayAttack { timestamp: u64 },
}
/// Tunable parameters for the handshake coordinator.
#[derive(Debug, Clone)]
pub struct HandshakeConfig {
/// Maximum age of an unfinished session; `cleanup_sessions` drops
/// sessions whose `started_at` is at least this far in the past.
pub timeout: Duration,
/// Protocol versions this node accepts; advertised in the init message
/// and consulted during version negotiation.
pub supported_versions: Vec<ProtocolVersion>,
/// Capabilities the peer must advertise; also advertised by this node.
pub required_capabilities: Vec<String>,
/// Capabilities this node advertises but does not require of the peer.
pub optional_capabilities: Vec<String>,
/// Largest tolerated difference between a message timestamp and the
/// local clock, in either direction.
pub max_timestamp_skew: Duration,
/// Whether mutual authentication is requested.
/// NOTE(review): this flag is not read by the code visible in this file.
pub mutual_auth: bool,
}
impl Default for HandshakeConfig {
    /// Builds the stock configuration: a 30 second handshake timeout, a
    /// 300 second timestamp skew, the current protocol version only, and
    /// mutual authentication switched on.
    fn default() -> Self {
        // Small helper turning a list of literals into owned strings.
        let owned = |names: &[&str]| -> Vec<String> {
            names.iter().map(|name| String::from(*name)).collect()
        };

        Self {
            mutual_auth: true,
            max_timestamp_skew: Duration::from_secs(300),
            timeout: Duration::from_secs(30),
            supported_versions: vec![ProtocolVersion::CURRENT],
            required_capabilities: owned(&["dag-consensus", "quantum-resistant-crypto"]),
            optional_capabilities: owned(&["anonymous-routing", "dark-addressing"]),
        }
    }
}
/// Bookkeeping for a single handshake, on either the initiating or the
/// responding side.
#[derive(Debug)]
pub struct HandshakeSession {
/// Locally generated identifier used as the key in the coordinator's map.
pub session_id: Uuid,
/// Peer identifier supplied by the caller of `initiate_handshake`;
/// `None` for sessions created from an incoming init message.
pub peer_id: Option<Vec<u8>>,
/// Current position of this session in the handshake.
pub state: HandshakeSessionState,
/// Protocol version agreed with the peer, once known.
pub negotiated_version: Option<ProtocolVersion>,
/// Capabilities the peer advertised.
pub peer_capabilities: Vec<String>,
/// Keys generated locally for this session.
pub our_keys: HandshakeKeys,
/// Public keys recorded for the peer, once received.
pub peer_keys: Option<PeerKeys>,
/// Key material derived from the KEM exchange, once available.
pub shared_secrets: Option<SharedSecrets>,
/// When the session was created; used for timeout-based cleanup.
pub started_at: SystemTime,
/// When the session was last updated.
pub last_activity: SystemTime,
/// Random value generated locally for this session; combined with the
/// peer's nonce during key derivation.
pub nonce: u64,
}
/// Lifecycle states of a [`HandshakeSession`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HandshakeSessionState {
/// Initial state of a locally initiated session before the init message
/// has been produced.
Waiting,
/// Initiator: the init message has been produced; a response is expected.
InitSent,
/// Responder: an init message has been accepted (transient; replaced by
/// `ResponseSent` once the response is signed).
InitReceived,
/// Responder: the response has been produced; a completion is expected.
ResponseSent,
/// Initiator: a valid response has been processed.
ResponseReceived,
/// The handshake finished successfully.
Completed,
/// The handshake failed. Treated as terminal by session cleanup.
/// NOTE(review): never assigned by the code visible in this file.
Failed,
}
/// Locally owned key material used during a handshake.
#[derive(Debug)]
pub struct HandshakeKeys {
/// ML-DSA keypair used to sign outgoing handshake messages.
pub signature_keypair: MlDsaKeyPair,
/// ML-KEM keypair; its secret key decapsulates the peer's ciphertext.
pub kem_keypair: KemKeyPair,
}
/// Public keys recorded for the remote side of a handshake.
#[derive(Debug, Clone)]
pub struct PeerKeys {
/// Key used to verify signatures on the peer's handshake messages.
pub signature_public_key: MlDsaPublicKey,
/// KEM public key associated with the peer entry.
pub kem_public_key: KemPublicKey,
}
/// Secret material produced by a successful KEM exchange.
///
/// The three derived keys are BLAKE3 keyed hashes over the KEM shared secret
/// followed by the combined nonce, each under a distinct label.
#[derive(Debug, Clone)]
pub struct SharedSecrets {
/// Raw shared secret returned by the KEM.
pub kem_shared_secret: SharedSecret,
/// Key derived under the `QuDAG-Encryption-Key` label.
pub encryption_key: Vec<u8>,
/// Key derived under the `QuDAG-MAC-Key` label.
pub mac_key: Vec<u8>,
/// Key derived under the `QuDAG-Session-Key` label.
pub session_key: Vec<u8>,
}
/// Payloads carried inside handshake messages (serialized with bincode by
/// the coordinator).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum HandshakeMessagePayload {
/// First message, sent by the initiator.
Init {
/// Version the initiator prefers.
protocol_version: ProtocolVersion,
/// Every version the initiator accepts.
supported_versions: Vec<ProtocolVersion>,
/// Capabilities advertised by the initiator.
capabilities: Vec<String>,
/// Serialized ML-DSA public key used to verify this message.
signature_public_key: Vec<u8>,
/// Serialized ML-KEM public key the responder encapsulates against.
kem_public_key: Vec<u8>,
/// Initiator's random nonce, mixed into key derivation.
nonce: u64,
/// Sender clock in milliseconds since the UNIX epoch.
timestamp: u64,
},
/// Reply sent by the responder.
Response {
/// Version selected by the responder.
protocol_version: ProtocolVersion,
/// Capabilities advertised by the responder.
capabilities: Vec<String>,
/// Serialized ML-DSA public key used to verify this message.
signature_public_key: Vec<u8>,
/// KEM ciphertext for the initiator, followed by the responder's
/// random nonce (mixed into key derivation).
kem_ciphertext: Vec<u8>, nonce: u64,
/// Sender clock in milliseconds since the UNIX epoch.
timestamp: u64,
},
/// Final confirmation sent by the initiator; `session_id` holds the
/// initiator's session UUID bytes.
Complete { session_id: Vec<u8>, timestamp: u64 },
/// Stand-alone version negotiation request.
VersionNegotiation {
/// Every version the sender accepts.
supported_versions: Vec<ProtocolVersion>,
/// Version the sender would like to use.
preferred_version: ProtocolVersion,
},
}
/// Drives handshakes: creates sessions, processes incoming handshake
/// messages and tracks per-session state.
pub struct HandshakeCoordinator {
/// Handshake parameters (timeouts, versions, capabilities).
config: HandshakeConfig,
/// All sessions currently tracked, keyed by local session id.
sessions: HashMap<Uuid, HandshakeSession>,
/// Long-term keys supplied at construction; stored but not read by the
/// methods visible in this file (hence the lint allowance).
#[allow(dead_code)]
identity_keys: HandshakeKeys,
/// Protocol state machine notified of each accepted handshake message.
state_machine: ProtocolStateMachine,
}
impl HandshakeCoordinator {
/// Creates a coordinator with no tracked sessions.
///
/// * `config` - handshake parameters.
/// * `identity_keys` - long-term keys kept alongside the coordinator.
/// * `state_machine` - protocol state machine notified of handshake progress.
pub fn new(
    config: HandshakeConfig,
    identity_keys: HandshakeKeys,
    state_machine: ProtocolStateMachine,
) -> Self {
    let sessions = HashMap::new();
    Self {
        state_machine,
        identity_keys,
        sessions,
        config,
    }
}
/// Generates a fresh ML-DSA signature keypair and ML-KEM-768 keypair.
///
/// # Errors
/// Returns [`HandshakeError::CryptoError`] when either key generation fails.
pub fn generate_keys() -> Result<HandshakeKeys, HandshakeError> {
    let mut rng = rand::thread_rng();
    let signature_keypair = match MlDsaKeyPair::generate(&mut rng) {
        Ok(pair) => pair,
        Err(e) => {
            return Err(HandshakeError::CryptoError {
                reason: format!("Failed to generate ML-DSA keypair: {:?}", e),
            })
        }
    };

    let (kem_public, kem_secret) = match MlKem768::keygen() {
        Ok(pair) => pair,
        Err(e) => {
            return Err(HandshakeError::CryptoError {
                reason: format!("Failed to generate ML-KEM keypair: {:?}", e),
            })
        }
    };

    // The keypair container stores raw bytes rather than typed keys.
    let public_key = kem_public.as_bytes().to_vec();
    let secret_key = kem_secret.as_bytes().to_vec();

    Ok(HandshakeKeys {
        signature_keypair,
        kem_keypair: KemKeyPair {
            public_key,
            secret_key,
        },
    })
}
/// Starts a handshake as the initiator.
///
/// Generates per-session keys, a session id and a nonce, builds and signs
/// the init message, and registers the session in the `InitSent` state.
///
/// Returns the new session id together with the signed init message that
/// the caller is expected to deliver to the peer.
///
/// # Errors
/// Fails when key generation, payload serialization or signing fails; in
/// that case no session is registered.
pub fn initiate_handshake(
    &mut self,
    peer_id: Option<Vec<u8>>,
) -> Result<(Uuid, Message), HandshakeError> {
    let our_keys = Self::generate_keys()?;
    let session_id = Uuid::new_v4();
    let nonce = rand::random::<u64>();
    let started_at = SystemTime::now();

    // Advertise required capabilities first, then the optional ones.
    let advertised: Vec<String> = self
        .config
        .required_capabilities
        .iter()
        .chain(self.config.optional_capabilities.iter())
        .cloned()
        .collect();

    let sent_at_ms = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap()
        .as_millis() as u64;

    let init = HandshakeMessagePayload::Init {
        protocol_version: ProtocolVersion::CURRENT,
        supported_versions: self.config.supported_versions.clone(),
        capabilities: advertised,
        signature_public_key: our_keys.signature_keypair.public_key().to_vec(),
        kem_public_key: our_keys.kem_keypair.public_key().to_vec(),
        nonce,
        timestamp: sent_at_ms,
    };

    let encoded = match bincode::serialize(&init) {
        Ok(bytes) => bytes,
        Err(e) => {
            return Err(HandshakeError::InvalidMessage {
                reason: format!("Failed to serialize handshake init: {:?}", e),
            })
        }
    };

    let mut message = Message::new(MessageType::Handshake(HandshakeType::Init), encoded);
    message.sign(&our_keys.signature_keypair)?;

    // The session is only recorded once the init message is ready to send.
    self.sessions.insert(
        session_id,
        HandshakeSession {
            session_id,
            peer_id,
            state: HandshakeSessionState::InitSent,
            negotiated_version: None,
            peer_capabilities: Vec::new(),
            our_keys,
            peer_keys: None,
            shared_secrets: None,
            started_at,
            last_activity: SystemTime::now(),
            nonce,
        },
    );

    info!("Initiated handshake session: {}", session_id);
    Ok((session_id, message))
}
/// Entry point for every incoming handshake message.
///
/// The message timestamp is validated first; the message is then routed to
/// the handler matching its handshake type. `session_id` is forwarded to
/// the handlers that operate on an existing session.
///
/// Returns the reply to send to the peer, if the step produces one.
///
/// # Errors
/// Propagates timestamp and handler errors, and returns
/// [`HandshakeError::InvalidMessage`] for anything that is not one of the
/// four handled handshake message types.
pub fn process_handshake_message(
    &mut self,
    message: &Message,
    session_id: Option<Uuid>,
) -> Result<Option<Message>, HandshakeError> {
    self.validate_timestamp(message)?;

    let kind = &message.msg_type;
    if let MessageType::Handshake(HandshakeType::Init) = kind {
        return self.process_handshake_init(message);
    }
    if let MessageType::Handshake(HandshakeType::Response) = kind {
        return self.process_handshake_response(message, session_id);
    }
    if let MessageType::Handshake(HandshakeType::Complete) = kind {
        return self.process_handshake_complete(message, session_id);
    }
    if let MessageType::Handshake(HandshakeType::VersionNegotiation) = kind {
        return self.process_version_negotiation(message, session_id);
    }

    Err(HandshakeError::InvalidMessage {
        reason: "Not a handshake message".to_string(),
    })
}
/// Handles an incoming init message as the responder.
///
/// Steps, in order: decode the payload, negotiate a protocol version,
/// check the peer's capabilities, parse the peer's public keys, verify the
/// message signature against the key carried in the payload, generate
/// session keys, encapsulate a shared secret to the peer's KEM key, derive
/// the session keys, and build and sign the response.
///
/// The new session is registered (in `ResponseSent` state) only after the
/// state machine has accepted the message, so a rejected init leaves no
/// half-open session holding derived secrets behind.
///
/// # Errors
/// Returns an error, without registering a session, when any of the steps
/// above fails or when the payload is not an `Init` payload.
fn process_handshake_init(
    &mut self,
    message: &Message,
) -> Result<Option<Message>, HandshakeError> {
    let payload: HandshakeMessagePayload =
        bincode::deserialize(&message.payload).map_err(|e| HandshakeError::InvalidMessage {
            reason: format!("Failed to deserialize handshake init: {:?}", e),
        })?;
    if let HandshakeMessagePayload::Init {
        protocol_version,
        supported_versions,
        capabilities,
        signature_public_key,
        kem_public_key,
        nonce,
        timestamp: _,
    } = payload
    {
        let negotiated_version =
            self.negotiate_version(&supported_versions, &protocol_version)?;
        self.verify_capabilities(&capabilities)?;

        let peer_signature_key =
            MlDsaPublicKey::from_bytes(&signature_public_key).map_err(|e| {
                HandshakeError::CryptoError {
                    reason: format!("Invalid peer signature key: {:?}", e),
                }
            })?;
        let peer_kem_key = KemPublicKey::from_bytes(&kem_public_key).map_err(|e| {
            HandshakeError::CryptoError {
                reason: format!("Invalid peer KEM key: {:?}", e),
            }
        })?;

        // The verification key comes from the message itself, so this only
        // proves the sender holds the matching private key.
        if !message.verify(&peer_signature_key)? {
            return Err(HandshakeError::InvalidCredentials);
        }

        let session_keys = Self::generate_keys()?;
        let session_id = Uuid::new_v4();
        let our_nonce = rand::random::<u64>();
        let started_at = SystemTime::now();

        let (kem_ciphertext, shared_secret) =
            MlKem768::encapsulate(&peer_kem_key).map_err(|e| HandshakeError::CryptoError {
                reason: format!("KEM encapsulation failed: {:?}", e),
            })?;
        // The nonces are combined with XOR, so argument order is irrelevant.
        let shared_secrets = self.derive_session_keys(&shared_secret, nonce, our_nonce)?;

        let response_payload = HandshakeMessagePayload::Response {
            protocol_version: negotiated_version.clone(),
            capabilities: [
                self.config.required_capabilities.clone(),
                self.config.optional_capabilities.clone(),
            ]
            .concat(),
            signature_public_key: session_keys.signature_keypair.public_key().to_vec(),
            kem_ciphertext: kem_ciphertext.as_bytes().to_vec(),
            nonce: our_nonce,
            timestamp: SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .expect("system clock is set before the UNIX epoch")
                .as_millis() as u64,
        };
        let response_bytes = bincode::serialize(&response_payload).map_err(|e| {
            HandshakeError::InvalidMessage {
                reason: format!("Failed to serialize handshake response: {:?}", e),
            }
        })?;
        let mut response_message = Message::new(
            MessageType::Handshake(HandshakeType::Response),
            response_bytes,
        );
        response_message.sign(&session_keys.signature_keypair)?;

        // Let the state machine veto the message before any session is
        // recorded; on error nothing has been stored.
        self.state_machine
            .process_message(message, Some(session_id))?;

        let session = HandshakeSession {
            session_id,
            peer_id: None,
            state: HandshakeSessionState::ResponseSent,
            negotiated_version: Some(negotiated_version),
            peer_capabilities: capabilities,
            our_keys: session_keys,
            peer_keys: Some(PeerKeys {
                signature_public_key: peer_signature_key,
                kem_public_key: peer_kem_key,
            }),
            shared_secrets: Some(shared_secrets),
            started_at,
            last_activity: SystemTime::now(),
            nonce: our_nonce,
        };
        self.sessions.insert(session_id, session);

        info!(
            "Processed handshake init, sending response for session: {}",
            session_id
        );
        Ok(Some(response_message))
    } else {
        Err(HandshakeError::InvalidMessage {
            reason: "Expected handshake init payload".to_string(),
        })
    }
}
fn process_handshake_response(
&mut self,
message: &Message,
session_id: Option<Uuid>,
) -> Result<Option<Message>, HandshakeError> {
let session_id = session_id.ok_or(HandshakeError::InvalidMessage {
reason: "Session ID required for handshake response".to_string(),
})?;
let session = self
.sessions
.get_mut(&session_id)
.ok_or(HandshakeError::InvalidMessage {
reason: format!("Session not found: {}", session_id),
})?;
if session.state != HandshakeSessionState::InitSent {
return Err(HandshakeError::InvalidMessage {
reason: format!("Invalid session state for response: {:?}", session.state),
});
}
let payload: HandshakeMessagePayload =
bincode::deserialize(&message.payload).map_err(|e| HandshakeError::InvalidMessage {
reason: format!("Failed to deserialize handshake response: {:?}", e),
})?;
if let HandshakeMessagePayload::Response {
protocol_version,
capabilities,
signature_public_key,
kem_ciphertext,
nonce,
timestamp: _,
} = payload
{
let peer_signature_key =
MlDsaPublicKey::from_bytes(&signature_public_key).map_err(|e| {
HandshakeError::CryptoError {
reason: format!("Invalid peer signature key: {:?}", e),
}
})?;
if !message.verify(&peer_signature_key)? {
return Err(HandshakeError::InvalidCredentials);
}
if !self.config.supported_versions.contains(&protocol_version) {
return Err(HandshakeError::VersionMismatch {
expected: ProtocolVersion::CURRENT,
actual: protocol_version,
});
}
{
for required_cap in &self.config.required_capabilities {
if !capabilities.contains(required_cap) {
return Err(HandshakeError::UnsupportedCapabilities {
capabilities: vec![required_cap.clone()],
});
}
}
}
let kem_ciphertext_bytes = KemCiphertext::from_bytes(&kem_ciphertext).map_err(|e| {
HandshakeError::CryptoError {
reason: format!("Invalid KEM ciphertext: {:?}", e),
}
})?;
let secret_key = SecretKey::from_bytes(session.our_keys.kem_keypair.secret_key())
.map_err(|e| HandshakeError::CryptoError {
reason: format!("Invalid secret key: {:?}", e),
})?;
let shared_secret =
MlKem768::decapsulate(&secret_key, &kem_ciphertext_bytes).map_err(|e| {
HandshakeError::CryptoError {
reason: format!("KEM decapsulation failed: {:?}", e),
}
})?;
let session_nonce = session.nonce;
let shared_secrets = {
let secret_bytes = shared_secret.as_bytes();
let combined_nonce = session_nonce ^ nonce;
let nonce_bytes = combined_nonce.to_be_bytes();
let mut key_material = Vec::new();
key_material.extend_from_slice(secret_bytes);
key_material.extend_from_slice(&nonce_bytes);
let encryption_key = blake3::keyed_hash(
blake3::hash(b"QuDAG-Encryption-Key").as_bytes(),
&key_material,
)
.as_bytes()
.to_vec();
let mac_key =
blake3::keyed_hash(blake3::hash(b"QuDAG-MAC-Key").as_bytes(), &key_material)
.as_bytes()
.to_vec();
let session_key = blake3::keyed_hash(
blake3::hash(b"QuDAG-Session-Key").as_bytes(),
&key_material,
)
.as_bytes()
.to_vec();
SharedSecrets {
kem_shared_secret: shared_secret.clone(),
encryption_key,
mac_key,
session_key,
}
};
session.negotiated_version = Some(protocol_version);
session.peer_capabilities = capabilities;
session.peer_keys = Some(PeerKeys {
signature_public_key: peer_signature_key,
kem_public_key: KemPublicKey::from_bytes(session.our_keys.kem_keypair.public_key())
.map_err(|e| HandshakeError::CryptoError {
reason: format!("Invalid public key: {:?}", e),
})?, });
session.shared_secrets = Some(shared_secrets);
session.state = HandshakeSessionState::ResponseReceived;
session.last_activity = SystemTime::now();
let complete_payload = HandshakeMessagePayload::Complete {
session_id: session_id.as_bytes().to_vec(),
timestamp: SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_millis() as u64,
};
let complete_bytes = bincode::serialize(&complete_payload).map_err(|e| {
HandshakeError::InvalidMessage {
reason: format!("Failed to serialize handshake complete: {:?}", e),
}
})?;
let mut complete_message = Message::new(
MessageType::Handshake(HandshakeType::Complete),
complete_bytes,
);
complete_message.sign(&session.our_keys.signature_keypair)?;
session.state = HandshakeSessionState::Completed;
self.state_machine
.process_message(message, Some(session_id))?;
info!(
"Processed handshake response, sending completion for session: {}",
session_id
);
Ok(Some(complete_message))
} else {
Err(HandshakeError::InvalidMessage {
reason: "Expected handshake response payload".to_string(),
})
}
}
/// Handles the initiator's completion message as the responder.
///
/// Requires `session_id` to name a session in the `ResponseSent` state.
/// The message signature is verified against the peer signature key stored
/// on the session. A session without stored peer keys is rejected instead
/// of being accepted unverified (fail closed).
///
/// The session is marked `Completed` only after the state machine has
/// accepted the message, so an error never leaves a session reporting a
/// finished handshake.
///
/// Always returns `Ok(None)` on success: no reply is sent.
///
/// # Errors
/// * [`HandshakeError::InvalidMessage`] - missing/unknown session id or
///   wrong session state.
/// * [`HandshakeError::InvalidCredentials`] - no peer keys are stored or the
///   signature does not verify.
/// * Message and state-machine errors are propagated.
fn process_handshake_complete(
    &mut self,
    message: &Message,
    session_id: Option<Uuid>,
) -> Result<Option<Message>, HandshakeError> {
    let session_id = session_id.ok_or_else(|| HandshakeError::InvalidMessage {
        reason: "Session ID required for handshake complete".to_string(),
    })?;
    let session =
        self.sessions
            .get_mut(&session_id)
            .ok_or_else(|| HandshakeError::InvalidMessage {
                reason: format!("Session not found: {}", session_id),
            })?;
    if session.state != HandshakeSessionState::ResponseSent {
        return Err(HandshakeError::InvalidMessage {
            reason: format!("Invalid session state for complete: {:?}", session.state),
        });
    }

    // Without a stored peer key the signature cannot be checked; refusing
    // is the only safe answer.
    let peer_keys = session
        .peer_keys
        .as_ref()
        .ok_or(HandshakeError::InvalidCredentials)?;
    if !message.verify(&peer_keys.signature_public_key)? {
        return Err(HandshakeError::InvalidCredentials);
    }
    // NOTE(review): the `Complete` payload (session id, timestamp) is not
    // decoded or cross-checked here — confirm whether it should be.

    // `state_machine` and `sessions` are distinct fields, so the session
    // borrow can stay alive across this call.
    self.state_machine
        .process_message(message, Some(session_id))?;

    session.state = HandshakeSessionState::Completed;
    session.last_activity = SystemTime::now();
    info!(
        "Handshake completed successfully for session: {}",
        session_id
    );
    Ok(None)
}
/// Handles a stand-alone version negotiation message.
///
/// The payload is decoded and a mutually acceptable version is computed
/// and logged at debug level. The result is not stored anywhere and no
/// reply is produced; the session id argument is ignored.
///
/// # Errors
/// Returns [`HandshakeError::InvalidMessage`] when the payload cannot be
/// decoded or is not a `VersionNegotiation` payload, and propagates the
/// version mismatch error when no common version exists.
fn process_version_negotiation(
    &mut self,
    message: &Message,
    _session_id: Option<Uuid>,
) -> Result<Option<Message>, HandshakeError> {
    let decoded: Result<HandshakeMessagePayload, _> = bincode::deserialize(&message.payload);
    let payload = decoded.map_err(|e| HandshakeError::InvalidMessage {
        reason: format!("Failed to deserialize version negotiation: {:?}", e),
    })?;

    match payload {
        HandshakeMessagePayload::VersionNegotiation {
            supported_versions,
            preferred_version,
        } => {
            let agreed = self.negotiate_version(&supported_versions, &preferred_version)?;
            debug!("Negotiated protocol version: {:?}", agreed);
            Ok(None)
        }
        _ => Err(HandshakeError::InvalidMessage {
            reason: "Expected version negotiation payload".to_string(),
        }),
    }
}
/// Picks the protocol version to use with a peer.
///
/// The peer's preferred version wins when it is one of our supported
/// versions. Otherwise the first of our supported versions (in
/// configuration order) that is compatible with any version offered by the
/// peer is chosen.
///
/// # Errors
/// Returns [`HandshakeError::VersionMismatch`] when neither rule yields a
/// version.
fn negotiate_version(
    &self,
    peer_versions: &[ProtocolVersion],
    peer_preferred: &ProtocolVersion,
) -> Result<ProtocolVersion, HandshakeError> {
    let ours = &self.config.supported_versions;
    if ours.contains(peer_preferred) {
        return Ok(peer_preferred.clone());
    }

    ours.iter()
        .find(|candidate| {
            peer_versions
                .iter()
                .any(|offered| candidate.is_compatible(offered))
        })
        .cloned()
        .ok_or_else(|| HandshakeError::VersionMismatch {
            expected: ProtocolVersion::CURRENT,
            actual: peer_preferred.clone(),
        })
}
/// Checks that the peer advertises every locally required capability.
///
/// # Errors
/// Returns [`HandshakeError::UnsupportedCapabilities`] naming the first
/// required capability (in configuration order) missing from
/// `peer_capabilities`.
fn verify_capabilities(&self, peer_capabilities: &[String]) -> Result<(), HandshakeError> {
    let first_missing = self
        .config
        .required_capabilities
        .iter()
        .find(|needed| !peer_capabilities.contains(*needed));

    match first_missing {
        Some(missing) => Err(HandshakeError::UnsupportedCapabilities {
            capabilities: vec![missing.clone()],
        }),
        None => Ok(()),
    }
}
/// Derives the encryption, MAC and session keys for a handshake.
///
/// The input keying material is the KEM shared secret followed by the
/// big-endian bytes of `our_nonce ^ peer_nonce`. Each output key is a
/// BLAKE3 keyed hash of that material, keyed with the BLAKE3 hash of a
/// purpose-specific label. Because XOR is commutative, both sides obtain
/// the same keys regardless of which nonce they consider "ours".
///
/// The current implementation never returns an error.
fn derive_session_keys(
    &self,
    shared_secret: &SharedSecret,
    our_nonce: u64,
    peer_nonce: u64,
) -> Result<SharedSecrets, HandshakeError> {
    let mixed_nonce = (our_nonce ^ peer_nonce).to_be_bytes();

    let mut ikm = Vec::new();
    ikm.extend_from_slice(shared_secret.as_bytes());
    ikm.extend_from_slice(&mixed_nonce);

    // One derivation per label; the label hash serves as the 32-byte key.
    let derive = |label: &[u8]| -> Vec<u8> {
        let label_key = blake3::hash(label);
        blake3::keyed_hash(label_key.as_bytes(), &ikm)
            .as_bytes()
            .to_vec()
    };

    let encryption_key = derive(b"QuDAG-Encryption-Key");
    let mac_key = derive(b"QuDAG-MAC-Key");
    let session_key = derive(b"QuDAG-Session-Key");

    Ok(SharedSecrets {
        kem_shared_secret: shared_secret.clone(),
        encryption_key,
        mac_key,
        session_key,
    })
}
/// Rejects messages whose timestamp is too far from the local clock.
///
/// `message.timestamp` and the local time are both compared as
/// milliseconds since the UNIX epoch. A message is accepted when the two
/// differ by at most `config.max_timestamp_skew` in either direction.
///
/// All arithmetic saturates: the timestamp comes from the peer, so a value
/// near `u64::MAX` must not overflow the comparison (a panic in debug
/// builds, a wrap-around in release builds). Oversized durations are
/// clamped to `u64::MAX` instead of being truncated.
///
/// # Errors
/// Returns [`HandshakeError::ReplayAttack`] when the timestamp is outside
/// the allowed window.
///
/// # Panics
/// Panics if the system clock is set before the UNIX epoch.
fn validate_timestamp(&self, message: &Message) -> Result<(), HandshakeError> {
    // Clamp a u128 millisecond count into u64 without silent truncation.
    let clamp_ms = |millis: u128| -> u64 { millis.min(u128::from(u64::MAX)) as u64 };

    let now = clamp_ms(
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("system clock is set before the UNIX epoch")
            .as_millis(),
    );
    let max_skew = clamp_ms(self.config.max_timestamp_skew.as_millis());
    let message_time = message.timestamp;

    let too_old = now > message_time.saturating_add(max_skew);
    let too_far_ahead = message_time > now.saturating_add(max_skew);
    if too_old || too_far_ahead {
        return Err(HandshakeError::ReplayAttack {
            timestamp: message_time,
        });
    }
    Ok(())
}
/// Looks up a tracked session by id.
///
/// Returns `None` when no session with that id is currently stored.
pub fn get_session(&self, session_id: &Uuid) -> Option<&HandshakeSession> {
self.sessions.get(session_id)
}
pub fn cleanup_sessions(&mut self) {
let now = SystemTime::now();
let timeout = self.config.timeout;
self.sessions.retain(|_, session| {
let elapsed = now
.duration_since(session.started_at)
.unwrap_or(Duration::ZERO);
match session.state {
HandshakeSessionState::Completed | HandshakeSessionState::Failed => false,
_ => elapsed < timeout,
}
});
}
/// Returns every tracked session that is neither `Completed` nor `Failed`.
///
/// The order of the returned sessions is unspecified (map iteration order).
pub fn get_active_sessions(&self) -> Vec<&HandshakeSession> {
    let mut active = Vec::new();
    for candidate in self.sessions.values() {
        match candidate.state {
            HandshakeSessionState::Completed | HandshakeSessionState::Failed => {}
            _ => active.push(candidate),
        }
    }
    active
}
/// Reports whether the given session exists and is in the `Completed`
/// state. Unknown session ids yield `false`.
pub fn is_handshake_completed(&self, session_id: &Uuid) -> bool {
    matches!(
        self.sessions.get(session_id),
        Some(found) if found.state == HandshakeSessionState::Completed
    )
}
/// Returns the key material derived for a session.
///
/// Yields `None` when the session is unknown or has not derived its
/// secrets yet.
pub fn get_shared_secrets(&self, session_id: &Uuid) -> Option<&SharedSecrets> {
    let found = self.sessions.get(session_id)?;
    found.shared_secrets.as_ref()
}
}