use po_crypto::aead::SessionCipher;
use po_crypto::exchange::EphemeralKeypair;
use po_crypto::identity::Identity;
use po_transport::traits::AsyncFrameTransport;
use po_wire::{FrameHeader, FrameType};
use crate::framer::{Framer, FramerError};
use crate::message::{HandshakeComplete, HandshakeInit, HandshakeReply};
use ed25519_dalek::{Signature, VerifyingKey};
use std::time::{SystemTime, UNIX_EPOCH};
/// Outcome of a successful handshake, returned by both
/// `perform_handshake_initiator` and `perform_handshake_responder`.
pub struct HandshakeResult {
    /// Cipher constructed from the session key derived during the handshake.
    /// On the initiator side it has already been used once, to encrypt the
    /// `HandshakeComplete` confirmation.
    pub cipher: SessionCipher,
    /// The peer's Ed25519 identity public key as carried in its handshake
    /// message; the peer's handshake signature was verified against this key.
    pub peer_pubkey: [u8; 32],
    /// Node identifier derived from the peer's verified Ed25519 key.
    pub peer_node_id: po_crypto::identity::NodeId,
}
pub async fn perform_handshake_initiator(
identity: &Identity,
transport: &mut dyn AsyncFrameTransport,
framer: &mut Framer,
) -> Result<HandshakeResult, HandshakeError> {
let ephemeral = EphemeralKeypair::generate();
let our_eph_pub = ephemeral.public_bytes();
let timestamp = now_millis();
let mut sign_data = Vec::with_capacity(41);
sign_data.push(1u8); sign_data.extend_from_slice(&our_eph_pub);
sign_data.extend_from_slice(×tamp.to_le_bytes());
let signature = identity.sign(&sign_data);
let init = HandshakeInit {
version: 1,
ed25519_pubkey: identity.public_key_bytes(),
x25519_ephemeral: our_eph_pub,
timestamp,
signature: signature.to_bytes().to_vec(),
};
let payload =
bincode::serialize(&init).map_err(|e| HandshakeError::Serialization(e.to_string()))?;
let header = FrameHeader {
frame_type: FrameType::HandshakeInit,
flags: po_wire::FrameFlags::default(),
channel_id: 0,
stream_id: 0,
payload_len: payload.len() as u64,
};
framer
.write_frame(transport, &header, &payload)
.await
.map_err(HandshakeError::Framer)?;
let (reply_header, reply_payload) = framer
.read_frame(transport)
.await
.map_err(HandshakeError::Framer)?
.ok_or(HandshakeError::ConnectionClosed)?;
if reply_header.frame_type != FrameType::HandshakeReply {
return Err(HandshakeError::UnexpectedFrame(reply_header.frame_type));
}
let reply: HandshakeReply = bincode::deserialize(&reply_payload)
.map_err(|e| HandshakeError::Serialization(e.to_string()))?;
let peer_verifying =
VerifyingKey::from_bytes(&reply.ed25519_pubkey).map_err(|_| HandshakeError::InvalidKey)?;
let mut verify_data = Vec::with_capacity(64);
verify_data.extend_from_slice(&our_eph_pub);
verify_data.extend_from_slice(&reply.x25519_ephemeral);
let peer_sig = Signature::from_bytes(
reply
.signature
.as_slice()
.try_into()
.map_err(|_| HandshakeError::InvalidSignature)?,
);
if !Identity::verify(&peer_verifying, &verify_data, &peer_sig) {
return Err(HandshakeError::InvalidSignature);
}
let context = build_session_context(&identity.public_key_bytes(), &reply.ed25519_pubkey);
let session_key = ephemeral
.derive_session_key(&reply.x25519_ephemeral, &context)
.map_err(|e| HandshakeError::KeyDerivation(e.to_string()))?;
let mut cipher = SessionCipher::new(session_key.as_bytes());
let confirmation = cipher
.encrypt(b"PO_READY", b"handshake-complete")
.map_err(|e| HandshakeError::Encryption(e.to_string()))?;
let complete = HandshakeComplete { confirmation };
let complete_payload =
bincode::serialize(&complete).map_err(|e| HandshakeError::Serialization(e.to_string()))?;
let complete_header = FrameHeader {
frame_type: FrameType::HandshakeComplete,
flags: po_wire::FrameFlags::default(),
channel_id: 0,
stream_id: 0,
payload_len: complete_payload.len() as u64,
};
framer
.write_frame(transport, &complete_header, &complete_payload)
.await
.map_err(HandshakeError::Framer)?;
let peer_node_id = po_crypto::identity::NodeId::from_public_key(&peer_verifying);
Ok(HandshakeResult {
cipher,
peer_pubkey: reply.ed25519_pubkey,
peer_node_id,
})
}
/// Runs the responder side of the three-message handshake over `transport`.
///
/// The responder:
/// 1. Waits for the peer's `HandshakeInit` and accepts it only if all of these hold:
///    - it announces version 1;
///    - its signature over `version || ephemeral || timestamp (little-endian)`
///      verifies against the Ed25519 key it carries;
///    - its timestamp is within 30 s of our clock.
/// 2. Answers with a `HandshakeReply` holding our identity key and a fresh X25519
///    ephemeral. It is signed over `initiator_ephemeral || our_ephemeral`.
/// 3. Derives the session key and requires a `HandshakeComplete` whose
///    confirmation decrypts to `"PO_READY"` under that key.
///
/// # Errors
/// - [`HandshakeError::Framer`] / [`HandshakeError::ConnectionClosed`] on transport
///   failure or early close.
/// - [`HandshakeError::UnexpectedFrame`] if a frame of the wrong type arrives.
/// - [`HandshakeError::Serialization`] if a message cannot be encoded or decoded.
/// - [`HandshakeError::UnsupportedVersion`], [`HandshakeError::InvalidKey`],
///   [`HandshakeError::InvalidSignature`], [`HandshakeError::TimestampExpired`]
///   if the initiator's message is rejected.
/// - [`HandshakeError::KeyDerivation`] if the session key cannot be derived.
/// - [`HandshakeError::ConfirmationFailed`] if the confirmation does not check out.
pub async fn perform_handshake_responder(
    identity: &Identity,
    transport: &mut dyn AsyncFrameTransport,
    framer: &mut Framer,
) -> Result<HandshakeResult, HandshakeError> {
    // Largest tolerated gap, in milliseconds, between the initiator's clock and ours.
    const MAX_DRIFT_MS: u64 = 30_000;

    // -- receive the initiator's offer --
    let (offer_hdr, offer_bytes) =
        match framer.read_frame(transport).await.map_err(HandshakeError::Framer)? {
            Some(frame) => frame,
            None => return Err(HandshakeError::ConnectionClosed),
        };
    if offer_hdr.frame_type != FrameType::HandshakeInit {
        return Err(HandshakeError::UnexpectedFrame(offer_hdr.frame_type));
    }
    let offer: HandshakeInit = bincode::deserialize(&offer_bytes)
        .map_err(|e| HandshakeError::Serialization(e.to_string()))?;
    if offer.version != 1 {
        return Err(HandshakeError::UnsupportedVersion(offer.version));
    }

    // -- authenticate it --
    let initiator_key =
        VerifyingKey::from_bytes(&offer.ed25519_pubkey).map_err(|_| HandshakeError::InvalidKey)?;
    // Same byte layout the initiator signed: version || ephemeral || timestamp (LE).
    let signed_by_initiator = [
        &[offer.version][..],
        &offer.x25519_ephemeral[..],
        &offer.timestamp.to_le_bytes()[..],
    ]
    .concat();
    let initiator_sig = Signature::from_bytes(
        offer
            .signature
            .as_slice()
            .try_into()
            .map_err(|_| HandshakeError::InvalidSignature)?,
    );
    if !Identity::verify(&initiator_key, &signed_by_initiator, &initiator_sig) {
        return Err(HandshakeError::InvalidSignature);
    }
    if now_millis().abs_diff(offer.timestamp) > MAX_DRIFT_MS {
        return Err(HandshakeError::TimestampExpired);
    }

    // -- answer with our ephemeral, signed together with theirs --
    let our_keypair = EphemeralKeypair::generate();
    let our_public = our_keypair.public_bytes();
    let transcript = [&offer.x25519_ephemeral[..], &our_public[..]].concat();
    let endorsement = identity.sign(&transcript);
    let answer = HandshakeReply {
        ed25519_pubkey: identity.public_key_bytes(),
        x25519_ephemeral: our_public,
        signature: endorsement.to_bytes().to_vec(),
    };
    let answer_bytes =
        bincode::serialize(&answer).map_err(|e| HandshakeError::Serialization(e.to_string()))?;
    let answer_hdr = FrameHeader {
        frame_type: FrameType::HandshakeReply,
        flags: po_wire::FrameFlags::default(),
        channel_id: 0,
        stream_id: 0,
        payload_len: answer_bytes.len() as u64,
    };
    framer
        .write_frame(transport, &answer_hdr, &answer_bytes)
        .await
        .map_err(HandshakeError::Framer)?;

    // -- derive the session key; context order is (initiator, responder) --
    let kdf_context = build_session_context(&offer.ed25519_pubkey, &identity.public_key_bytes());
    let session_key = our_keypair
        .derive_session_key(&offer.x25519_ephemeral, &kdf_context)
        .map_err(|e| HandshakeError::KeyDerivation(e.to_string()))?;
    let session = SessionCipher::new(session_key.as_bytes());

    // -- the initiator must prove it derived the same key --
    let (done_hdr, done_bytes) =
        match framer.read_frame(transport).await.map_err(HandshakeError::Framer)? {
            Some(frame) => frame,
            None => return Err(HandshakeError::ConnectionClosed),
        };
    if done_hdr.frame_type != FrameType::HandshakeComplete {
        return Err(HandshakeError::UnexpectedFrame(done_hdr.frame_type));
    }
    let done: HandshakeComplete = bincode::deserialize(&done_bytes)
        .map_err(|e| HandshakeError::Serialization(e.to_string()))?;
    match session.decrypt(&done.confirmation, b"handshake-complete") {
        Ok(plain) if plain == b"PO_READY" => {}
        _ => return Err(HandshakeError::ConfirmationFailed),
    }

    let peer_node_id = po_crypto::identity::NodeId::from_public_key(&initiator_key);
    Ok(HandshakeResult {
        cipher: session,
        peer_pubkey: offer.ed25519_pubkey,
        peer_node_id,
    })
}
/// Builds the key-derivation context that binds a session key to both identities.
///
/// Layout: `b"po-v1-" || initiator_pubkey || responder_pubkey` (6 + 32 + 32 = 70 bytes).
/// Argument order matters. Both peers must pass the initiator's key first, or
/// their contexts differ (and presumably so do the derived keys).
fn build_session_context(initiator_pubkey: &[u8; 32], responder_pubkey: &[u8; 32]) -> Vec<u8> {
    // Domain-separation label. Both peers must use identical bytes, so treat it
    // as part of the protocol.
    const LABEL: &[u8] = b"po-v1-";

    // Reserve exactly what is written. The old `64 + 10` hint did not match the
    // 6-byte label.
    let mut ctx = Vec::with_capacity(LABEL.len() + initiator_pubkey.len() + responder_pubkey.len());
    ctx.extend_from_slice(LABEL);
    ctx.extend_from_slice(initiator_pubkey);
    ctx.extend_from_slice(responder_pubkey);
    ctx
}
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// Returns 0 if the system clock reads earlier than the epoch. If the `u128`
/// millisecond count does not fit in `u64`, it saturates at `u64::MAX` instead
/// of silently truncating as an `as` cast would.
fn now_millis() -> u64 {
    let elapsed = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default();
    u64::try_from(elapsed.as_millis()).unwrap_or(u64::MAX)
}
/// Ways `perform_handshake_initiator` / `perform_handshake_responder` can fail.
#[derive(Debug)]
pub enum HandshakeError {
    /// Reading or writing a handshake frame failed.
    Framer(FramerError),
    /// A handshake message could not be encoded or decoded; carries the codec's message.
    Serialization(String),
    /// A peer signature had the wrong length or did not verify.
    InvalidSignature,
    /// The peer's Ed25519 public-key bytes were rejected.
    InvalidKey,
    /// The initiator announced a protocol version other than 1.
    UnsupportedVersion(u8),
    /// The initiator's timestamp was more than 30 s away from the responder's clock.
    TimestampExpired,
    /// Deriving the session key from the ephemeral exchange failed.
    KeyDerivation(String),
    /// Encrypting the handshake confirmation failed.
    Encryption(String),
    /// The confirmation did not decrypt, or decrypted to the wrong plaintext.
    ConfirmationFailed,
    /// The peer closed the connection before the handshake finished.
    ConnectionClosed,
    /// A frame of another type arrived where a specific handshake frame was expected.
    UnexpectedFrame(FrameType),
}

impl std::fmt::Display for HandshakeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Variants carrying a payload format it after a fixed prefix and return
        // directly; the remaining variants map to a constant message.
        let message = match self {
            Self::Framer(inner) => return write!(f, "framer: {inner}"),
            Self::Serialization(detail) => return write!(f, "serialization: {detail}"),
            Self::UnsupportedVersion(ver) => return write!(f, "unsupported protocol version: {ver}"),
            Self::KeyDerivation(detail) => return write!(f, "key derivation: {detail}"),
            Self::Encryption(detail) => return write!(f, "encryption: {detail}"),
            Self::UnexpectedFrame(kind) => return write!(f, "unexpected frame type: {kind}"),
            Self::InvalidSignature => "invalid signature",
            Self::InvalidKey => "invalid public key",
            Self::TimestampExpired => "handshake timestamp expired",
            Self::ConfirmationFailed => "handshake confirmation failed",
            Self::ConnectionClosed => "connection closed during handshake",
        };
        f.write_str(message)
    }
}

// Uses the default `source()`, so a wrapped `FramerError` is shown via
// `Display` only and is not exposed through the error-source chain.
impl std::error::Error for HandshakeError {}