use crate::crypto::pq::{MlKemKeyPair, MlKemPublicKey, SphincsPublicKey};
use crate::crypto::noise::NoiseSession;
use crate::crypto::ratchet::PqTripleRatchet;
use crate::protocol::PeerCapabilities;
use crate::{Error, Result, Identity};
use serde::{Deserialize, Serialize};
/// Progress of a [`Handshake`] through the Hello → Response → Confirm exchange.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum HandshakeState {
    /// Freshly constructed; no handshake message has been produced or consumed.
    New,
    /// Initiator only: `Hello` has been produced and a `Response` is awaited.
    Initiated,
    /// Responder only: `Hello` has been answered and a `Confirm` is awaited.
    Responding,
    /// Both signatures relevant to this side have been produced or checked.
    Complete,
    /// Terminal state: the handshake was aborted (e.g. a peer signature did not verify).
    Failed,
}
/// Messages exchanged during a [`Handshake`].
///
/// These are serialized with `bincode`, which encodes fields positionally, so the
/// declaration order of every field below is part of the wire format and of the
/// signed transcript. Do not reorder fields.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum HandshakeMessage {
    /// First message, produced by the initiator.
    Hello {
        /// Sender's advertised capabilities (e.g. `pq_ratchet`).
        capabilities: PeerCapabilities,
        /// Sender's ML-KEM public key bytes, if offered.
        mlkem_pk: Option<Vec<u8>>,
        /// Reserved for a classical X25519 share; this module always sends `None`.
        x25519_pk: Option<Vec<u8>>,
        /// Sender's long-term identity public key bytes.
        identity_pk: Vec<u8>,
    },
    /// Responder's answer to a `Hello`.
    Response {
        /// Responder's advertised capabilities.
        capabilities: PeerCapabilities,
        /// ML-KEM ciphertext encapsulated to the Hello's `mlkem_pk`, when PQC is in use.
        mlkem_ct: Option<Vec<u8>>,
        /// Reserved for a classical X25519 share; this module always sends `None`.
        x25519_pk: Option<Vec<u8>>,
        /// Responder's long-term identity public key bytes.
        identity_pk: Vec<u8>,
        /// Responder's signature over the transcript as it stood before this message.
        signature: Vec<u8>,
    },
    /// Initiator's final message.
    Confirm {
        /// Initiator's signature over the transcript (Hello followed by Response).
        signature: Vec<u8>,
    },
}
/// One side of a Hello → Response → Confirm authenticated key exchange.
pub struct Handshake {
    /// Current position in the exchange.
    state: HandshakeState,
    /// `true` when built with `Handshake::initiator`, `false` for `Handshake::responder`.
    initiator: bool,
    /// Whether the post-quantum (ML-KEM) exchange is in use for this session.
    use_pqc: bool,
    /// Our long-term identity, used to sign the transcript.
    our_identity: Identity,
    /// Peer identity decoded from its `Hello` / `Response`, once one has been processed.
    their_identity: Option<Identity>,
    /// ML-KEM key pair generated by this side during the exchange.
    mlkem_keypair: Option<MlKemKeyPair>,
    /// Concatenated bincode encodings of the handshake messages seen so far;
    /// this byte string is what each side signs and verifies.
    transcript: Vec<u8>,
}
impl Handshake {
pub fn initiator(identity: Identity) -> Self {
Self {
our_identity: identity,
their_identity: None,
state: HandshakeState::New,
initiator: true,
mlkem_keypair: None,
use_pqc: true,
transcript: Vec::new(),
}
}
pub fn responder(identity: Identity) -> Self {
Self {
our_identity: identity,
their_identity: None,
state: HandshakeState::New,
initiator: false,
mlkem_keypair: None,
use_pqc: true,
transcript: Vec::new(),
}
}
pub fn state(&self) -> HandshakeState {
self.state
}
pub fn is_complete(&self) -> bool {
self.state == HandshakeState::Complete
}
pub fn create_hello(&mut self, our_caps: &PeerCapabilities) -> Result<HandshakeMessage> {
if self.state != HandshakeState::New {
return Err(Error::KeyExchange("Invalid state for hello".into()));
}
let mlkem = MlKemKeyPair::generate()?;
let mlkem_pk = mlkem.public_key().as_bytes().to_vec();
self.mlkem_keypair = Some(mlkem);
let msg = HandshakeMessage::Hello {
capabilities: our_caps.clone(),
mlkem_pk: Some(mlkem_pk),
x25519_pk: None, identity_pk: self.our_identity.public_key().as_bytes().to_vec(),
};
let serialized = bincode::serialize(&msg)
.map_err(|e| Error::Serialization(e.to_string()))?;
self.transcript.extend(&serialized);
self.state = HandshakeState::Initiated;
Ok(msg)
}
pub fn process_hello(
&mut self,
msg: HandshakeMessage,
our_caps: &PeerCapabilities,
) -> Result<HandshakeMessage> {
if self.state != HandshakeState::New {
return Err(Error::KeyExchange("Invalid state for processing hello".into()));
}
let (their_caps, mlkem_pk_bytes, identity_pk_bytes) = match &msg {
HandshakeMessage::Hello {
capabilities,
mlkem_pk,
identity_pk,
..
} => (capabilities.clone(), mlkem_pk.clone(), identity_pk.clone()),
_ => return Err(Error::InvalidMessage("Expected Hello".into())),
};
let identity_pk = SphincsPublicKey::from_bytes(&identity_pk_bytes)?;
self.their_identity = Some(Identity::from_public_key(identity_pk));
self.use_pqc = our_caps.pq_ratchet && their_caps.pq_ratchet;
let mlkem = MlKemKeyPair::generate()?;
let mlkem_ct = if self.use_pqc {
if let Some(pk_bytes) = mlkem_pk_bytes {
let their_pk = MlKemPublicKey::from_bytes(&pk_bytes)?;
let (ct, _ss) = their_pk.encapsulate()?;
Some(ct.as_bytes().to_vec())
} else {
None
}
} else {
None
};
self.mlkem_keypair = Some(mlkem);
let serialized = bincode::serialize(&msg)
.map_err(|e| Error::Serialization(e.to_string()))?;
self.transcript.extend(&serialized);
let signature = self.our_identity.sign(&self.transcript)?;
let response = HandshakeMessage::Response {
capabilities: our_caps.clone(),
mlkem_ct,
x25519_pk: None,
identity_pk: self.our_identity.public_key().as_bytes().to_vec(),
signature,
};
let serialized = bincode::serialize(&response)
.map_err(|e| Error::Serialization(e.to_string()))?;
self.transcript.extend(&serialized);
self.state = HandshakeState::Responding;
Ok(response)
}
pub fn process_response(&mut self, msg: HandshakeMessage) -> Result<HandshakeMessage> {
if self.state != HandshakeState::Initiated {
return Err(Error::KeyExchange("Invalid state for processing response".into()));
}
let (their_caps, _mlkem_ct, identity_pk_bytes, their_sig) = match &msg {
HandshakeMessage::Response {
capabilities,
mlkem_ct,
identity_pk,
signature,
..
} => (capabilities.clone(), mlkem_ct.clone(), identity_pk.clone(), signature.clone()),
_ => return Err(Error::InvalidMessage("Expected Response".into())),
};
let identity_pk = SphincsPublicKey::from_bytes(&identity_pk_bytes)?;
self.their_identity = Some(Identity::from_public_key(identity_pk.clone()));
if !identity_pk.verify(&self.transcript, &their_sig)? {
self.state = HandshakeState::Failed;
return Err(Error::InvalidSignature);
}
self.use_pqc = their_caps.pq_ratchet;
let serialized = bincode::serialize(&msg)
.map_err(|e| Error::Serialization(e.to_string()))?;
self.transcript.extend(&serialized);
let signature = self.our_identity.sign(&self.transcript)?;
let confirm = HandshakeMessage::Confirm { signature };
self.state = HandshakeState::Complete;
Ok(confirm)
}
pub fn process_confirm(&mut self, msg: HandshakeMessage) -> Result<()> {
if self.state != HandshakeState::Responding {
return Err(Error::KeyExchange("Invalid state for processing confirm".into()));
}
let their_sig = match &msg {
HandshakeMessage::Confirm { signature } => signature.clone(),
_ => return Err(Error::InvalidMessage("Expected Confirm".into())),
};
let their_identity = self.their_identity.as_ref()
.ok_or_else(|| Error::KeyExchange("No their identity".into()))?;
if !their_identity.verify(&self.transcript, &their_sig)? {
self.state = HandshakeState::Failed;
return Err(Error::InvalidSignature);
}
self.state = HandshakeState::Complete;
Ok(())
}
pub fn their_identity(&self) -> Option<&Identity> {
self.their_identity.as_ref()
}
pub fn uses_pqc(&self) -> bool {
self.use_pqc
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Drives Hello → Response → Confirm with default capabilities on both sides and
    /// returns (initiator, responder, confirm-not-yet-delivered).
    fn run_to_confirm() -> (Handshake, Handshake, HandshakeMessage) {
        let caps = PeerCapabilities::default();
        let mut initiator = Handshake::initiator(Identity::generate().unwrap());
        let mut responder = Handshake::responder(Identity::generate().unwrap());
        let hello = initiator.create_hello(&caps).unwrap();
        let response = responder.process_hello(hello, &caps).unwrap();
        let confirm = initiator.process_response(response).unwrap();
        (initiator, responder, confirm)
    }

    #[test]
    fn test_handshake_creation() {
        let identity = Identity::generate().unwrap();
        let hs = Handshake::initiator(identity);
        assert_eq!(hs.state(), HandshakeState::New);
    }

    #[test]
    fn test_create_hello() {
        let identity = Identity::generate().unwrap();
        let mut hs = Handshake::initiator(identity);
        let caps = PeerCapabilities::default();
        let msg = hs.create_hello(&caps).unwrap();
        assert!(matches!(msg, HandshakeMessage::Hello { .. }));
        assert_eq!(hs.state(), HandshakeState::Initiated);
    }

    #[test]
    fn test_create_hello_twice_is_rejected() {
        let mut hs = Handshake::initiator(Identity::generate().unwrap());
        let caps = PeerCapabilities::default();
        hs.create_hello(&caps).unwrap();
        assert!(hs.create_hello(&caps).is_err());
        // A state-order error must not disturb an in-progress handshake.
        assert_eq!(hs.state(), HandshakeState::Initiated);
    }

    #[test]
    fn test_full_handshake_completes() {
        let (initiator, mut responder, confirm) = run_to_confirm();
        assert!(initiator.is_complete());
        assert_eq!(responder.state(), HandshakeState::Responding);

        responder.process_confirm(confirm).unwrap();
        assert!(responder.is_complete());
        assert_eq!(initiator.uses_pqc(), responder.uses_pqc());
        assert!(initiator.their_identity().is_some());
        assert!(responder.their_identity().is_some());
    }

    #[test]
    fn test_tampered_confirm_is_rejected() {
        let (_initiator, mut responder, confirm) = run_to_confirm();
        let tampered = match confirm {
            HandshakeMessage::Confirm { mut signature } => {
                if let Some(byte) = signature.first_mut() {
                    *byte ^= 0x01;
                }
                HandshakeMessage::Confirm { signature }
            }
            other => panic!("expected Confirm, got {:?}", other),
        };
        assert!(responder.process_confirm(tampered).is_err());
        assert!(!responder.is_complete());
    }

    #[test]
    fn test_wrong_message_type_is_rejected() {
        let caps = PeerCapabilities::default();
        let mut initiator = Handshake::initiator(Identity::generate().unwrap());
        let hello = initiator.create_hello(&caps).unwrap();
        // The initiator expects a Response here, not a Hello.
        assert!(initiator.process_response(hello).is_err());
        assert!(!initiator.is_complete());
    }
}