use std::time::{SystemTime, UNIX_EPOCH};
use ed25519_dalek::{Signature, Signer, Verifier};
use rand::Rng;
use sha2::{Digest, Sha256};
pub use ed25519_dalek::{SigningKey, VerifyingKey};
use crate::crypto::kem::{EncapsulatedKey, HybridKem, HybridKeypair, HybridPublicKey};
use crate::error::Result;
use crate::error::{SessionError, SrxError};
/// Four-byte magic prefix that opens every handshake frame (`"SRXH"`).
const MAGIC: &[u8; 4] = &[b'S', b'R', b'X', b'H'];
/// Wire-format version byte; frames carrying any other value are rejected when decoded.
const VERSION: u8 = 1;

/// Frame type byte of the initiator's opening message.
const MSG_CLIENT_HELLO: u8 = 1;
/// Frame type byte of the responder's reply carrying its hybrid public key.
const MSG_SERVER_HELLO: u8 = 2;
/// Frame type byte of the initiator's final message carrying the KEM encapsulation.
const MSG_CLIENT_FINISHED: u8 = 3;

/// Length in bytes of the random nonce placed in ClientHello.
const NONCE_LEN: usize = 16;
/// Length in bytes of an Ed25519 signature appended to a signed ServerHello.
const ED25519_SIG_LEN: usize = 64;
/// Which side of the handshake a `Handshake` instance drives.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HandshakeRole {
/// Client side: produces ClientHello, then consumes ServerHello and produces ClientFinished.
Initiator,
/// Server side: consumes ClientHello, produces ServerHello, then consumes ClientFinished.
Responder,
}
/// One endpoint of the SRX handshake, tracked as a small state machine.
///
/// Initiator flow: `client_hello` then `finalize`.
/// Responder flow: `server_hello` then `server_finish`.
/// Once this side completes, `master_secret` returns the 32-byte secret from the hybrid KEM.
// NOTE(review): no `Debug` derive — presumably deliberate so key material is never printed;
// confirm before adding one.
pub struct Handshake {
/// Which side of the exchange this instance drives.
role: HandshakeRole,
/// Current protocol step; every public step checks it before acting.
state: HandshakeState,
/// Responder's hybrid KEM keypair, generated in `server_hello` and consumed in `server_finish`.
server_kp: Option<HybridKeypair>,
/// Shared secret, set when this side reaches `Completed`.
master_secret: Option<[u8; 32]>,
/// Responder's Ed25519 signing key; when set, ServerHello carries a transcript signature.
server_identity: Option<SigningKey>,
/// Initiator's pinned server key; when set, ServerHello must carry a valid signature.
trusted_server: Option<VerifyingKey>,
/// Responder's copy of the ClientHello frame it received.
client_hello_wire: Option<Vec<u8>>,
/// Initiator's copy of the ClientHello frame it sent (input to signature verification).
initiator_client_hello: Option<Vec<u8>>,
}
/// Protocol step a `Handshake` has reached; shared by both roles.
#[derive(Debug)]
enum HandshakeState {
/// Freshly constructed; nothing sent or received yet.
Initial,
/// Initiator has produced ClientHello and is waiting for ServerHello.
AwaitingServerHello,
/// Responder has produced ServerHello and is waiting for ClientFinished.
AwaitingClientFinished,
/// Master secret established; every further handshake step is rejected.
Completed,
}
/// Hashes the ServerHello transcript: SHA-256 over the full ClientHello frame followed by
/// the serialized hybrid public key.
///
/// The responder signs this digest and a key-pinning initiator verifies it, so the
/// signature ties the server's KEM public key to one specific ClientHello.
fn transcript_digest(client_hello_wire: &[u8], hybrid_pk_bytes: &[u8]) -> [u8; 32] {
    Sha256::new()
        .chain_update(client_hello_wire)
        .chain_update(hybrid_pk_bytes)
        .finalize()
        .into()
}
/// Splits a ServerHello payload into the hybrid public key and an optional Ed25519 signature.
///
/// Layout: big-endian `u16` Kyber length, the Kyber bytes, a 32-byte classical key, and then
/// either nothing or exactly `ED25519_SIG_LEN` signature bytes.
///
/// # Errors
/// Returns `HandshakeFailed` if the payload is truncated or has an unexpected trailer length,
/// if `expect_sig` is set but no signature is present, or if the key/signature bytes do not
/// parse.
fn parse_server_hello_payload(
    data: &[u8],
    expect_sig: bool,
) -> Result<(HybridPublicKey, Option<Signature>)> {
    // Local shorthand for the handshake error every rejection below produces.
    let reject = |why: &'static str| SrxError::Session(SessionError::HandshakeFailed(why.into()));

    if data.len() < 2 {
        return Err(reject("ServerHello payload too short"));
    }
    let kyber_len = usize::from(u16::from_be_bytes([data[0], data[1]]));
    let key_end = 2 + kyber_len + 32;

    // Bytes remaining after the hybrid key: 0 (unsigned) or one signature (signed).
    let trailer_len = match data.len().checked_sub(key_end) {
        Some(n) => n,
        None => return Err(reject("ServerHello hybrid key truncated")),
    };

    match trailer_len {
        0 if expect_sig => Err(reject("missing Ed25519 server signature")),
        0 => Ok((HybridPublicKey::from_bytes(&data[..key_end])?, None)),
        ED25519_SIG_LEN => {
            let key = HybridPublicKey::from_bytes(&data[..key_end])?;
            let signature = Signature::try_from(&data[key_end..])
                .map_err(|_| reject("invalid Ed25519 signature bytes"))?;
            Ok((key, Some(signature)))
        }
        _ => Err(reject("bad ServerHello payload length")),
    }
}
impl Handshake {
pub fn new_initiator() -> Self {
Self {
role: HandshakeRole::Initiator,
state: HandshakeState::Initial,
server_kp: None,
master_secret: None,
server_identity: None,
trusted_server: None,
client_hello_wire: None,
initiator_client_hello: None,
}
}
pub fn new_initiator_trust_server(trusted_server: VerifyingKey) -> Self {
Self {
role: HandshakeRole::Initiator,
state: HandshakeState::Initial,
server_kp: None,
master_secret: None,
server_identity: None,
trusted_server: Some(trusted_server),
client_hello_wire: None,
initiator_client_hello: None,
}
}
pub fn new_responder() -> Self {
Self {
role: HandshakeRole::Responder,
state: HandshakeState::Initial,
server_kp: None,
master_secret: None,
server_identity: None,
trusted_server: None,
client_hello_wire: None,
initiator_client_hello: None,
}
}
pub fn new_responder_with_identity(server_identity: SigningKey) -> Self {
Self {
role: HandshakeRole::Responder,
state: HandshakeState::Initial,
server_kp: None,
master_secret: None,
server_identity: Some(server_identity),
trusted_server: None,
client_hello_wire: None,
initiator_client_hello: None,
}
}
#[must_use]
pub fn role(&self) -> HandshakeRole {
self.role
}
#[must_use]
pub fn master_secret(&self) -> Option<[u8; 32]> {
self.master_secret
}
pub fn client_hello(&mut self) -> Result<Vec<u8>> {
if self.role != HandshakeRole::Initiator {
return Err(SrxError::Session(SessionError::HandshakeFailed(
"only initiator sends ClientHello".into(),
)));
}
match self.state {
HandshakeState::Initial => {}
_ => {
return Err(SrxError::Session(SessionError::HandshakeFailed(
"invalid initiator state for ClientHello".into(),
)));
}
}
let timestamp = SystemTime::now()
.duration_since(UNIX_EPOCH)
.map(|d| d.as_secs())
.unwrap_or(0);
let mut nonce = [0u8; NONCE_LEN];
rand::rng().fill_bytes(&mut nonce);
let mut payload = Vec::with_capacity(8 + NONCE_LEN);
payload.extend_from_slice(×tamp.to_be_bytes());
payload.extend_from_slice(&nonce);
let wire = encode_msg(MSG_CLIENT_HELLO, &payload);
self.initiator_client_hello = Some(wire.clone());
self.state = HandshakeState::AwaitingServerHello;
Ok(wire)
}
pub fn server_hello(&mut self, client_hello: &[u8]) -> Result<Vec<u8>> {
if self.role != HandshakeRole::Responder {
return Err(SrxError::Session(SessionError::HandshakeFailed(
"only responder handles ClientHello".into(),
)));
}
match self.state {
HandshakeState::Initial => {}
_ => {
return Err(SrxError::Session(SessionError::HandshakeFailed(
"invalid responder state for server_hello".into(),
)));
}
}
let _payload = decode_msg(MSG_CLIENT_HELLO, client_hello)?;
self.client_hello_wire = Some(client_hello.to_vec());
let server_kp = HybridKem::generate_keypair();
let pk_bytes = server_kp.public.to_bytes();
self.server_kp = Some(server_kp);
self.state = HandshakeState::AwaitingClientFinished;
let payload = if let Some(ref sk) = self.server_identity {
let ch = self.client_hello_wire.as_ref().ok_or_else(|| {
SrxError::Session(SessionError::HandshakeFailed(
"client hello wire missing".into(),
))
})?;
let digest = transcript_digest(ch, &pk_bytes);
let sig = sk.sign(&digest);
let mut p = pk_bytes;
p.extend_from_slice(&sig.to_bytes());
p
} else {
pk_bytes
};
Ok(encode_msg(MSG_SERVER_HELLO, &payload))
}
pub fn finalize(&mut self, server_hello: &[u8]) -> Result<Vec<u8>> {
if self.role != HandshakeRole::Initiator {
return Err(SrxError::Session(SessionError::HandshakeFailed(
"only initiator calls finalize".into(),
)));
}
match self.state {
HandshakeState::AwaitingServerHello => {}
_ => {
return Err(SrxError::Session(SessionError::HandshakeFailed(
"initiator not awaiting ServerHello".into(),
)));
}
}
let raw = decode_msg(MSG_SERVER_HELLO, server_hello)?;
let expect_sig = self.trusted_server.is_some();
let (server_pk, sig) = parse_server_hello_payload(raw, expect_sig)?;
if let Some(ref vk) = self.trusted_server {
let sig = sig.ok_or_else(|| {
SrxError::Session(SessionError::HandshakeFailed(
"server signature missing".into(),
))
})?;
let ch = self.initiator_client_hello.as_ref().ok_or_else(|| {
SrxError::Session(SessionError::HandshakeFailed(
"client hello wire missing".into(),
))
})?;
let digest = transcript_digest(ch, &server_pk.to_bytes());
vk.verify(&digest, &sig).map_err(|_| {
SrxError::Session(SessionError::HandshakeFailed(
"Ed25519 server signature verification failed".into(),
))
})?;
} else if sig.is_some() {
return Err(SrxError::Session(SessionError::HandshakeFailed(
"unexpected server signature (no pinned server key)".into(),
)));
}
let (encap, master) = HybridKem::encapsulate(&server_pk)?;
self.master_secret = Some(master);
self.state = HandshakeState::Completed;
let fin = encap.to_bytes();
Ok(encode_msg(MSG_CLIENT_FINISHED, &fin))
}
pub fn server_finish(&mut self, client_finished: &[u8]) -> Result<()> {
if self.role != HandshakeRole::Responder {
return Err(SrxError::Session(SessionError::HandshakeFailed(
"only responder calls server_finish".into(),
)));
}
match self.state {
HandshakeState::AwaitingClientFinished => {}
_ => {
return Err(SrxError::Session(SessionError::HandshakeFailed(
"responder not awaiting ClientFinished".into(),
)));
}
}
let enc_raw = decode_msg(MSG_CLIENT_FINISHED, client_finished)?;
let encap = EncapsulatedKey::from_bytes(enc_raw)?;
let mut server_kp = self.server_kp.take().ok_or_else(|| {
SrxError::Session(SessionError::HandshakeFailed(
"server keypair missing".into(),
))
})?;
let master = HybridKem::decapsulate(&mut server_kp, &encap)?;
self.master_secret = Some(master);
self.server_kp = None;
self.state = HandshakeState::Completed;
Ok(())
}
}
/// Frames a handshake message.
///
/// Layout: `MAGIC` (4 bytes) | `VERSION` (1) | `msg_type` (1) |
/// payload length as big-endian `u32` (4) | payload.
///
/// # Panics
/// Panics if `payload` is longer than `u32::MAX` bytes, which the length field cannot
/// represent. Payloads built by this module are expected to be far smaller. Previously such a
/// length was silently truncated, producing a corrupt frame.
fn encode_msg(msg_type: u8, payload: &[u8]) -> Vec<u8> {
    let declared_len =
        u32::try_from(payload.len()).expect("handshake payload exceeds u32 length field");
    let mut frame = Vec::with_capacity(MAGIC.len() + 1 + 1 + 4 + payload.len());
    frame.extend_from_slice(MAGIC);
    frame.push(VERSION);
    frame.push(msg_type);
    frame.extend_from_slice(&declared_len.to_be_bytes());
    frame.extend_from_slice(payload);
    frame
}
/// Validates a handshake frame and returns a borrow of its payload.
///
/// Expected layout: `MAGIC` (4 bytes) | `VERSION` (1) | message type (1) |
/// payload length as big-endian `u32` (4) | payload.
///
/// # Errors
/// Returns `HandshakeFailed` if `data` is shorter than the 10-byte header, has the wrong magic,
/// version, or message type, or if the declared length differs from the payload bytes present.
fn decode_msg(expected_type: u8, data: &[u8]) -> Result<&[u8]> {
    const HEADER_LEN: usize = 10;
    let fail = |why: &'static str| SrxError::Session(SessionError::HandshakeFailed(why.into()));

    if data.len() < HEADER_LEN {
        return Err(fail("message too short"));
    }
    let (header, payload) = data.split_at(HEADER_LEN);
    if header[..4] != *MAGIC {
        return Err(fail("bad magic"));
    }
    if header[4] != VERSION {
        return Err(fail("bad version"));
    }
    if header[5] != expected_type {
        return Err(fail("unexpected message type"));
    }
    let declared = u32::from_be_bytes([header[6], header[7], header[8], header[9]]);
    // Compare the declared length with the bytes actually present, instead of computing
    // `HEADER_LEN + declared`. That sum can overflow `usize` on 32-bit targets for a hostile
    // length field, which panics in debug builds.
    if usize::try_from(declared).ok() != Some(payload.len()) {
        return Err(fail("length mismatch"));
    }
    Ok(payload)
}
#[cfg(test)]
mod tests {
    use super::*;
    use rand::Rng;

    /// Draws a fresh Ed25519 signing key from the thread RNG.
    fn random_signing_key() -> SigningKey {
        let mut seed = [0u8; 32];
        rand::rng().fill_bytes(&mut seed);
        SigningKey::from_bytes(&seed)
    }

    /// Runs the full three-message exchange and returns (client, server) master secrets.
    fn run(client: &mut Handshake, server: &mut Handshake) -> ([u8; 32], [u8; 32]) {
        let ch = client.client_hello().unwrap();
        let sh = server.server_hello(&ch).unwrap();
        let cf = client.finalize(&sh).unwrap();
        server.server_finish(&cf).unwrap();
        (
            client.master_secret().expect("client master"),
            server.master_secret().expect("server master"),
        )
    }

    #[test]
    fn full_handshake_matches_master_secret() {
        let (c, s) = run(&mut Handshake::new_initiator(), &mut Handshake::new_responder());
        assert_eq!(c, s);
    }

    #[test]
    fn signed_server_hello_roundtrip() {
        let sk = random_signing_key();
        let vk = sk.verifying_key();
        let (c, s) = run(
            &mut Handshake::new_initiator_trust_server(vk),
            &mut Handshake::new_responder_with_identity(sk),
        );
        assert_eq!(c, s);
    }

    #[test]
    fn wrong_server_identity_fails() {
        let sk = random_signing_key();
        let other = random_signing_key();
        let mut client = Handshake::new_initiator_trust_server(other.verifying_key());
        let mut server = Handshake::new_responder_with_identity(sk);
        let ch = client.client_hello().unwrap();
        let sh = server.server_hello(&ch).unwrap();
        assert!(client.finalize(&sh).is_err());
    }

    #[test]
    fn tampered_signature_fails() {
        let sk = random_signing_key();
        let mut client = Handshake::new_initiator_trust_server(sk.verifying_key());
        let mut server = Handshake::new_responder_with_identity(sk);
        let ch = client.client_hello().unwrap();
        let mut sh = server.server_hello(&ch).unwrap();
        // The signature is the final field of the frame; flip one of its bits.
        *sh.last_mut().unwrap() ^= 0x01;
        assert!(client.finalize(&sh).is_err());
        assert!(client.master_secret().is_none());
    }

    #[test]
    fn signature_is_bound_to_client_hello() {
        // A ServerHello signed over a different ClientHello must not verify for this client.
        let sk = random_signing_key();
        let mut client = Handshake::new_initiator_trust_server(sk.verifying_key());
        let mut other_client = Handshake::new_initiator();
        let mut server = Handshake::new_responder_with_identity(sk);
        client.client_hello().unwrap();
        let other_ch = other_client.client_hello().unwrap();
        let sh = server.server_hello(&other_ch).unwrap();
        assert!(client.finalize(&sh).is_err());
    }

    #[test]
    fn pinned_client_rejects_unsigned_server() {
        let mut client =
            Handshake::new_initiator_trust_server(random_signing_key().verifying_key());
        let mut server = Handshake::new_responder();
        let ch = client.client_hello().unwrap();
        let sh = server.server_hello(&ch).unwrap();
        assert!(client.finalize(&sh).is_err());
    }

    #[test]
    fn unpinned_client_rejects_signed_server() {
        let mut client = Handshake::new_initiator();
        let mut server = Handshake::new_responder_with_identity(random_signing_key());
        let ch = client.client_hello().unwrap();
        let sh = server.server_hello(&ch).unwrap();
        assert!(client.finalize(&sh).is_err());
    }

    #[test]
    fn out_of_order_and_wrong_role_calls_fail() {
        let mut client = Handshake::new_initiator();
        let mut server = Handshake::new_responder();
        assert!(client.finalize(b"").is_err());
        assert!(server.server_finish(b"").is_err());
        assert!(server.client_hello().is_err());

        let ch = client.client_hello().unwrap();
        assert!(client.client_hello().is_err());
        assert!(client.server_hello(&ch).is_err());

        server.server_hello(&ch).unwrap();
        assert!(server.server_hello(&ch).is_err());
    }

    #[test]
    fn wrong_message_type_fails() {
        let mut client = Handshake::new_initiator();
        let mut bad = client.client_hello().unwrap();
        // Byte 5 is the message-type field of the frame header.
        bad[5] = 99;
        let mut server = Handshake::new_responder();
        assert!(server.server_hello(&bad).is_err());
    }

    #[test]
    fn truncated_frame_fails_without_consuming_state() {
        let mut client = Handshake::new_initiator();
        let ch = client.client_hello().unwrap();
        let mut server = Handshake::new_responder();
        assert!(server.server_hello(&ch[..ch.len() - 1]).is_err());
        // The rejected frame must not advance the responder's state.
        assert!(server.server_hello(&ch).is_ok());
    }
}