use crate::PeerIdentity;
use crate::noise::{self, NoiseError, NoiseSession};
use crate::transport::{LinkDirection, LinkId, LinkStats, TransportAddr, TransportId};
use crate::utils::index::SessionIndex;
use secp256k1::Keypair;
use std::fmt;
/// Where a connection is in its Noise handshake.
///
/// Rendered by [`fmt::Display`] as a snake_case label (see [`HandshakeState::as_str`]).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum HandshakeState {
    /// Nothing has been sent or processed yet.
    Initial,
    /// The initiator has written message 1 and is waiting for message 2.
    SentMsg1,
    /// The responder has read message 1.
    ///
    /// NOTE(review): `PeerConnection` in this module never assigns this state;
    /// its responder path goes directly from `Initial` to `Complete`.
    ReceivedMsg1,
    /// The handshake finished and a Noise session was derived.
    Complete,
    /// The handshake was abandoned (e.g. via `PeerConnection::mark_failed`).
    Failed,
}

impl HandshakeState {
    /// Returns `true` for every state that is neither `Complete` nor `Failed`.
    pub fn is_in_progress(&self) -> bool {
        matches!(
            self,
            HandshakeState::Initial | HandshakeState::SentMsg1 | HandshakeState::ReceivedMsg1
        )
    }

    /// Returns `true` only for [`HandshakeState::Complete`].
    pub fn is_complete(&self) -> bool {
        matches!(self, HandshakeState::Complete)
    }

    /// Returns `true` only for [`HandshakeState::Failed`].
    pub fn is_failed(&self) -> bool {
        matches!(self, HandshakeState::Failed)
    }

    /// Static snake_case label for this state, as used by `Display`.
    ///
    /// Useful where a `&'static str` is wanted without allocating a `String`.
    pub fn as_str(&self) -> &'static str {
        match self {
            HandshakeState::Initial => "initial",
            HandshakeState::SentMsg1 => "sent_msg1",
            HandshakeState::ReceivedMsg1 => "received_msg1",
            HandshakeState::Complete => "complete",
            HandshakeState::Failed => "failed",
        }
    }
}

impl fmt::Display for HandshakeState {
    /// Writes the label from [`HandshakeState::as_str`].
    ///
    /// Uses `Formatter::pad` so width, fill, alignment and precision flags
    /// (e.g. `{:<12}` in log tables) are honoured. Plain `{}` output is
    /// unchanged.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.pad(self.as_str())
    }
}
/// State for one connection to a peer: direction, Noise handshake progress,
/// timing, session indices, transport info and handshake-message resend
/// bookkeeping.
pub struct PeerConnection {
    /// Link this connection belongs to.
    link_id: LinkId,
    /// `Outbound` if we initiated the handshake, `Inbound` if the peer did.
    direction: LinkDirection,
    /// Current position in the handshake state machine.
    handshake_state: HandshakeState,
    /// Outbound: the peer we dialed (set at construction).
    /// Inbound: filled in from the static key learned in handshake message 1.
    expected_identity: Option<PeerIdentity>,
    /// Initiator-side Noise state, held between sending msg1 and reading msg2.
    noise_handshake: Option<noise::HandshakeState>,
    /// Session derived from a completed handshake, until moved out by the caller.
    noise_session: Option<NoiseSession>,
    /// Creation time, in the caller's millisecond clock.
    started_at: u64,
    /// Time of the last recorded activity, in the same clock as `started_at`.
    last_activity: u64,
    /// Per-link statistics (exposed via accessors; not updated by this type's
    /// handshake logic).
    link_stats: LinkStats,
    /// Session index we assigned; only stored here. Presumably used by the
    /// caller to route packets — TODO confirm.
    our_index: Option<SessionIndex>,
    /// Session index the peer assigned; only stored here.
    their_index: Option<SessionIndex>,
    /// Transport the connection runs over, if known.
    transport_id: Option<TransportId>,
    /// Remote transport address, if known.
    source_addr: Option<TransportAddr>,
    /// Peer's epoch bytes as reported by the Noise handshake.
    remote_epoch: Option<[u8; 8]>,
    /// Encoded handshake message 1 kept by the caller (presumably for
    /// retransmission — TODO confirm).
    handshake_msg1: Option<Vec<u8>>,
    /// Encoded handshake message 2 kept by the caller.
    handshake_msg2: Option<Vec<u8>>,
    /// Retransmissions recorded since `handshake_msg1` was last set.
    resend_count: u32,
    /// Absolute time (caller's ms clock) at which the next resend is due.
    next_resend_at_ms: u64,
}
impl PeerConnection {
    /// Creates a connection we are initiating towards `expected_identity`.
    ///
    /// Starts in [`HandshakeState::Initial`]; call
    /// [`start_handshake`](Self::start_handshake) to produce message 1.
    pub fn outbound(
        link_id: LinkId,
        expected_identity: PeerIdentity,
        current_time_ms: u64,
    ) -> Self {
        Self::fresh(
            link_id,
            LinkDirection::Outbound,
            Some(expected_identity),
            None,
            None,
            current_time_ms,
        )
    }

    /// Creates a connection accepted from a peer whose identity is not yet
    /// known.
    ///
    /// The identity is learned from message 1 in
    /// [`receive_handshake_init`](Self::receive_handshake_init).
    pub fn inbound(link_id: LinkId, current_time_ms: u64) -> Self {
        Self::fresh(link_id, LinkDirection::Inbound, None, None, None, current_time_ms)
    }

    /// Like [`inbound`](Self::inbound), but also records the transport id
    /// and source address supplied by the caller.
    pub fn inbound_with_transport(
        link_id: LinkId,
        transport_id: TransportId,
        source_addr: TransportAddr,
        current_time_ms: u64,
    ) -> Self {
        Self::fresh(
            link_id,
            LinkDirection::Inbound,
            None,
            Some(transport_id),
            Some(source_addr),
            current_time_ms,
        )
    }

    /// Shared constructor used by the public ones.
    ///
    /// Every connection starts in `Initial` with no Noise state, no indices,
    /// no cached handshake messages and an empty resend schedule.
    fn fresh(
        link_id: LinkId,
        direction: LinkDirection,
        expected_identity: Option<PeerIdentity>,
        transport_id: Option<TransportId>,
        source_addr: Option<TransportAddr>,
        current_time_ms: u64,
    ) -> Self {
        Self {
            link_id,
            direction,
            handshake_state: HandshakeState::Initial,
            expected_identity,
            noise_handshake: None,
            noise_session: None,
            started_at: current_time_ms,
            last_activity: current_time_ms,
            link_stats: LinkStats::new(),
            our_index: None,
            their_index: None,
            transport_id,
            source_addr,
            remote_epoch: None,
            handshake_msg1: None,
            handshake_msg2: None,
            resend_count: 0,
            next_resend_at_ms: 0,
        }
    }

    /// Link this connection belongs to.
    pub fn link_id(&self) -> LinkId {
        self.link_id
    }

    /// Whether we initiated (`Outbound`) or accepted (`Inbound`) the connection.
    pub fn direction(&self) -> LinkDirection {
        self.direction
    }

    /// Current handshake state.
    pub fn handshake_state(&self) -> HandshakeState {
        self.handshake_state
    }

    /// Peer identity: the dialed peer for outbound connections, or the one
    /// learned from message 1 for inbound connections (`None` before that).
    pub fn expected_identity(&self) -> Option<&PeerIdentity> {
        self.expected_identity.as_ref()
    }

    /// `true` if we initiated this connection.
    pub fn is_outbound(&self) -> bool {
        self.direction == LinkDirection::Outbound
    }

    /// `true` if the peer initiated this connection.
    pub fn is_inbound(&self) -> bool {
        self.direction == LinkDirection::Inbound
    }

    /// See [`HandshakeState::is_in_progress`].
    pub fn is_in_progress(&self) -> bool {
        self.handshake_state.is_in_progress()
    }

    /// See [`HandshakeState::is_complete`].
    pub fn is_complete(&self) -> bool {
        self.handshake_state.is_complete()
    }

    /// See [`HandshakeState::is_failed`].
    pub fn is_failed(&self) -> bool {
        self.handshake_state.is_failed()
    }

    /// Creation time, in the caller's millisecond clock.
    pub fn started_at(&self) -> u64 {
        self.started_at
    }

    /// Time of the last recorded activity.
    pub fn last_activity(&self) -> u64 {
        self.last_activity
    }

    /// Time elapsed since creation; saturates to 0 if the clock went backwards.
    pub fn duration(&self, current_time_ms: u64) -> u64 {
        current_time_ms.saturating_sub(self.started_at)
    }

    /// Time elapsed since the last activity; saturates to 0 if the clock went
    /// backwards.
    pub fn idle_time(&self, current_time_ms: u64) -> u64 {
        current_time_ms.saturating_sub(self.last_activity)
    }

    /// Per-link statistics.
    pub fn link_stats(&self) -> &LinkStats {
        &self.link_stats
    }

    /// Mutable access to the per-link statistics.
    pub fn link_stats_mut(&mut self) -> &mut LinkStats {
        &mut self.link_stats
    }

    /// Session index we assigned, if set.
    pub fn our_index(&self) -> Option<SessionIndex> {
        self.our_index
    }

    /// Records the session index we assigned.
    pub fn set_our_index(&mut self, index: SessionIndex) {
        self.our_index = Some(index);
    }

    /// Session index the peer assigned, if set.
    pub fn their_index(&self) -> Option<SessionIndex> {
        self.their_index
    }

    /// Records the session index the peer assigned.
    pub fn set_their_index(&mut self, index: SessionIndex) {
        self.their_index = Some(index);
    }

    /// Transport this connection runs over, if known.
    pub fn transport_id(&self) -> Option<TransportId> {
        self.transport_id
    }

    /// Records the transport this connection runs over.
    pub fn set_transport_id(&mut self, id: TransportId) {
        self.transport_id = Some(id);
    }

    /// Remote transport address, if known.
    pub fn source_addr(&self) -> Option<&TransportAddr> {
        self.source_addr.as_ref()
    }

    /// Records the remote transport address.
    pub fn set_source_addr(&mut self, addr: TransportAddr) {
        self.source_addr = Some(addr);
    }

    /// Peer's epoch as reported by the Noise handshake; `None` until a
    /// handshake has completed on this connection.
    pub fn remote_epoch(&self) -> Option<[u8; 8]> {
        self.remote_epoch
    }

    /// Stores encoded message 1 and restarts the resend schedule: the resend
    /// count goes back to 0 and the next resend is due at `first_resend_at_ms`.
    pub fn set_handshake_msg1(&mut self, msg1: Vec<u8>, first_resend_at_ms: u64) {
        self.handshake_msg1 = Some(msg1);
        self.resend_count = 0;
        self.next_resend_at_ms = first_resend_at_ms;
    }

    /// Stores encoded message 2. Does not touch the resend schedule.
    pub fn set_handshake_msg2(&mut self, msg2: Vec<u8>) {
        self.handshake_msg2 = Some(msg2);
    }

    /// Stored message 1, if any.
    pub fn handshake_msg1(&self) -> Option<&[u8]> {
        self.handshake_msg1.as_deref()
    }

    /// Stored message 2, if any.
    pub fn handshake_msg2(&self) -> Option<&[u8]> {
        self.handshake_msg2.as_deref()
    }

    /// Resends recorded since message 1 was last stored.
    pub fn resend_count(&self) -> u32 {
        self.resend_count
    }

    /// Absolute time at which the next resend is due.
    pub fn next_resend_at_ms(&self) -> u64 {
        self.next_resend_at_ms
    }

    /// Records one more resend and schedules the next at `next_resend_at_ms`.
    pub fn record_resend(&mut self, next_resend_at_ms: u64) {
        // Saturate instead of overflowing (which would panic in debug builds).
        self.resend_count = self.resend_count.saturating_add(1);
        self.next_resend_at_ms = next_resend_at_ms;
    }

    /// Starts an outbound handshake and returns message 1 to send.
    ///
    /// On success the connection moves to [`HandshakeState::SentMsg1`] and
    /// keeps the Noise initiator state for
    /// [`complete_handshake`](Self::complete_handshake).
    ///
    /// # Errors
    /// - `NoiseError::WrongState` if the connection is not outbound or not in
    ///   `Initial`.
    /// - Any error from writing message 1. The connection is left unchanged,
    ///   so the call may be retried.
    ///
    /// # Panics
    /// If an outbound connection has no expected identity. The `outbound`
    /// constructor always sets one.
    pub fn start_handshake(
        &mut self,
        our_keypair: Keypair,
        epoch: [u8; 8],
        current_time_ms: u64,
    ) -> Result<Vec<u8>, NoiseError> {
        if self.direction != LinkDirection::Outbound {
            return Err(NoiseError::WrongState {
                expected: "outbound connection".to_string(),
                got: "inbound connection".to_string(),
            });
        }
        if self.handshake_state != HandshakeState::Initial {
            return Err(NoiseError::WrongState {
                expected: "initial state".to_string(),
                got: self.handshake_state.to_string(),
            });
        }
        let remote_static = self
            .expected_identity
            .as_ref()
            .expect("outbound must have expected identity")
            .pubkey_full();
        let mut hs = noise::HandshakeState::new_initiator(our_keypair, remote_static);
        hs.set_local_epoch(epoch);
        let msg1 = hs.write_message_1()?;
        self.noise_handshake = Some(hs);
        self.handshake_state = HandshakeState::SentMsg1;
        self.last_activity = current_time_ms;
        Ok(msg1)
    }

    /// Processes an inbound message 1 and returns message 2 to send back.
    ///
    /// On success the peer identity and epoch are recorded, a Noise session
    /// is stored, and the connection moves straight to
    /// [`HandshakeState::Complete`].
    ///
    /// # Errors
    /// - `NoiseError::WrongState` if the connection is not inbound or not in
    ///   `Initial`.
    /// - Any error from reading message 1, writing message 2 or deriving the
    ///   session. No field of the connection is modified in that case, so the
    ///   caller may retry with another message 1.
    pub fn receive_handshake_init(
        &mut self,
        our_keypair: Keypair,
        epoch: [u8; 8],
        message: &[u8],
        current_time_ms: u64,
    ) -> Result<Vec<u8>, NoiseError> {
        if self.direction != LinkDirection::Inbound {
            return Err(NoiseError::WrongState {
                expected: "inbound connection".to_string(),
                got: "outbound connection".to_string(),
            });
        }
        if self.handshake_state != HandshakeState::Initial {
            return Err(NoiseError::WrongState {
                expected: "initial state".to_string(),
                got: self.handshake_state.to_string(),
            });
        }
        let mut hs = noise::HandshakeState::new_responder(our_keypair);
        hs.set_local_epoch(epoch);
        hs.read_message_1(message)?;
        let remote_static = *hs
            .remote_static()
            .expect("remote static available after msg1");
        let identity = PeerIdentity::from_pubkey_full(remote_static);
        let remote_epoch = hs.remote_epoch();
        let msg2 = hs.write_message_2()?;
        let session = hs.into_session()?;

        // Commit only after every fallible step succeeded, so a failure above
        // never leaves a half-updated connection behind.
        self.expected_identity = Some(identity);
        self.remote_epoch = remote_epoch;
        self.noise_session = Some(session);
        self.handshake_state = HandshakeState::Complete;
        self.last_activity = current_time_ms;
        Ok(msg2)
    }

    /// Processes message 2 on an outbound connection and finishes the
    /// handshake.
    ///
    /// On success the peer epoch and Noise session are stored and the
    /// connection moves to [`HandshakeState::Complete`].
    ///
    /// # Errors
    /// - `NoiseError::WrongState` if the connection is not in `SentMsg1`.
    ///   The connection is left unchanged.
    /// - Any error from reading message 2 or deriving the session. The
    ///   pending Noise state has been consumed at that point and cannot be
    ///   retried, so the connection is marked [`HandshakeState::Failed`].
    pub fn complete_handshake(
        &mut self,
        message: &[u8],
        current_time_ms: u64,
    ) -> Result<(), NoiseError> {
        if self.handshake_state != HandshakeState::SentMsg1 {
            return Err(NoiseError::WrongState {
                expected: "sent_msg1 state".to_string(),
                got: self.handshake_state.to_string(),
            });
        }
        match self.finish_initiator(message) {
            Ok((session, remote_epoch)) => {
                self.remote_epoch = remote_epoch;
                self.noise_session = Some(session);
                self.handshake_state = HandshakeState::Complete;
                self.last_activity = current_time_ms;
                Ok(())
            }
            Err(err) => {
                // Staying in `SentMsg1` with `noise_handshake == None` would
                // make the next msg2 hit an empty handshake slot, so record
                // the failure explicitly instead.
                self.mark_failed();
                Err(err)
            }
        }
    }

    /// Takes the pending initiator state, feeds it message 2 and derives the
    /// session.
    ///
    /// Returns the session together with the peer's epoch. Always leaves
    /// `noise_handshake` empty.
    fn finish_initiator(
        &mut self,
        message: &[u8],
    ) -> Result<(NoiseSession, Option<[u8; 8]>), NoiseError> {
        let mut hs = self
            .noise_handshake
            .take()
            .ok_or_else(|| NoiseError::WrongState {
                expected: "pending noise handshake".to_string(),
                got: "none".to_string(),
            })?;
        hs.read_message_2(message)?;
        let remote_epoch = hs.remote_epoch();
        let session = hs.into_session()?;
        Ok((session, remote_epoch))
    }

    /// Moves the Noise session out of the connection.
    ///
    /// Returns `None` unless the handshake is `Complete`, or if the session
    /// was already taken.
    pub fn take_session(&mut self) -> Option<NoiseSession> {
        if self.handshake_state == HandshakeState::Complete {
            self.noise_session.take()
        } else {
            None
        }
    }

    /// `true` if the handshake is `Complete` and the session has not been
    /// taken yet.
    pub fn has_session(&self) -> bool {
        self.handshake_state == HandshakeState::Complete && self.noise_session.is_some()
    }

    /// Marks the handshake as failed and drops any pending initiator state.
    ///
    /// A stored session is not dropped, but it becomes unreachable through
    /// `has_session` / `take_session`, because both require `Complete`.
    pub fn mark_failed(&mut self) {
        self.handshake_state = HandshakeState::Failed;
        self.noise_handshake = None;
    }

    /// Records activity at `current_time_ms`.
    pub fn touch(&mut self, current_time_ms: u64) {
        self.last_activity = current_time_ms;
    }

    /// `true` when the idle time strictly exceeds `timeout_ms`.
    pub fn is_timed_out(&self, current_time_ms: u64, timeout_ms: u64) -> bool {
        self.idle_time(current_time_ms) > timeout_ms
    }
}
impl fmt::Debug for PeerConnection {
    /// Diagnostic view of the connection.
    ///
    /// The Noise handshake and session objects appear only as presence flags;
    /// their contents are never printed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let pending_handshake = self.noise_handshake.is_some();
        let established_session = self.noise_session.is_some();

        let mut builder = f.debug_struct("PeerConnection");
        builder.field("link_id", &self.link_id);
        builder.field("direction", &self.direction);
        builder.field("handshake_state", &self.handshake_state);
        builder.field("expected_identity", &self.expected_identity);
        builder.field("has_noise_handshake", &pending_handshake);
        builder.field("has_noise_session", &established_session);
        builder.field("our_index", &self.our_index);
        builder.field("their_index", &self.their_index);
        builder.field("transport_id", &self.transport_id);
        builder.field("started_at", &self.started_at);
        builder.field("last_activity", &self.last_activity);
        builder.finish()
    }
}
/// Unit tests for `HandshakeState` and `PeerConnection`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Identity;
    use rand::Rng;

    /// A `PeerIdentity` built (via `from_pubkey`) from a freshly generated
    /// `Identity`.
    fn make_peer_identity() -> PeerIdentity {
        let identity = Identity::generate();
        PeerIdentity::from_pubkey(identity.pubkey())
    }

    /// Keypair of a freshly generated `Identity`.
    fn make_keypair() -> Keypair {
        let identity = Identity::generate();
        identity.keypair()
    }

    /// Eight random bytes to use as a local epoch.
    fn make_epoch() -> [u8; 8] {
        let mut epoch = [0u8; 8];
        rand::rng().fill_bytes(&mut epoch);
        epoch
    }

    /// The state predicates partition the variants into in-progress,
    /// complete and failed.
    #[test]
    fn test_handshake_state_properties() {
        assert!(HandshakeState::Initial.is_in_progress());
        assert!(HandshakeState::SentMsg1.is_in_progress());
        assert!(HandshakeState::ReceivedMsg1.is_in_progress());
        assert!(!HandshakeState::Complete.is_in_progress());
        assert!(!HandshakeState::Failed.is_in_progress());
        assert!(HandshakeState::Complete.is_complete());
        assert!(HandshakeState::Failed.is_failed());
    }

    /// The outbound constructor sets direction, `Initial` state, the given
    /// identity and start time.
    #[test]
    fn test_outbound_connection() {
        let identity = make_peer_identity();
        let conn = PeerConnection::outbound(LinkId::new(1), identity, 1000);
        assert!(conn.is_outbound());
        assert!(!conn.is_inbound());
        assert_eq!(conn.handshake_state(), HandshakeState::Initial);
        assert!(conn.expected_identity().is_some());
        assert_eq!(conn.started_at(), 1000);
    }

    /// The inbound constructor starts with no identity.
    #[test]
    fn test_inbound_connection() {
        let conn = PeerConnection::inbound(LinkId::new(2), 2000);
        assert!(conn.is_inbound());
        assert!(!conn.is_outbound());
        assert_eq!(conn.handshake_state(), HandshakeState::Initial);
        assert!(conn.expected_identity().is_none());
        assert_eq!(conn.started_at(), 2000);
    }

    /// End-to-end handshake: msg1 -> responder (which learns the
    /// initiator's identity and epoch), msg2 -> initiator, then both sides'
    /// sessions interoperate on an encrypt/decrypt round trip.
    #[test]
    fn test_full_handshake_flow() {
        let initiator_identity = Identity::generate();
        let responder_identity = Identity::generate();
        let initiator_keypair = initiator_identity.keypair();
        let responder_keypair = responder_identity.keypair();
        let initiator_epoch = make_epoch();
        let responder_epoch = make_epoch();
        let responder_peer_id = PeerIdentity::from_pubkey_full(responder_identity.pubkey_full());
        let mut initiator_conn = PeerConnection::outbound(LinkId::new(1), responder_peer_id, 1000);
        let mut responder_conn = PeerConnection::inbound(LinkId::new(2), 1000);
        let msg1 = initiator_conn
            .start_handshake(initiator_keypair, initiator_epoch, 1100)
            .unwrap();
        assert_eq!(initiator_conn.handshake_state(), HandshakeState::SentMsg1);
        let msg2 = responder_conn
            .receive_handshake_init(responder_keypair, responder_epoch, &msg1, 1200)
            .unwrap();
        // The responder completes immediately after writing msg2.
        assert_eq!(responder_conn.handshake_state(), HandshakeState::Complete);
        let discovered = responder_conn.expected_identity().unwrap();
        assert_eq!(discovered.pubkey(), initiator_identity.pubkey());
        assert_eq!(responder_conn.remote_epoch(), Some(initiator_epoch));
        initiator_conn.complete_handshake(&msg2, 1300).unwrap();
        assert_eq!(initiator_conn.handshake_state(), HandshakeState::Complete);
        assert_eq!(initiator_conn.remote_epoch(), Some(responder_epoch));
        assert!(initiator_conn.has_session());
        assert!(responder_conn.has_session());
        let mut init_session = initiator_conn.take_session().unwrap();
        let mut resp_session = responder_conn.take_session().unwrap();
        let plaintext = b"test message";
        let ciphertext = init_session.encrypt(plaintext).unwrap();
        let decrypted = resp_session.decrypt(&ciphertext).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    /// Duration and idle time are measured from creation (t=1000). The
    /// timeout check is strict: idle 500 is not > 1000, idle 1500 is.
    #[test]
    fn test_connection_timing() {
        let identity = make_peer_identity();
        let conn = PeerConnection::outbound(LinkId::new(1), identity, 1000);
        assert_eq!(conn.duration(1500), 500);
        assert_eq!(conn.idle_time(1500), 500);
        assert!(!conn.is_timed_out(1500, 1000));
        assert!(conn.is_timed_out(2500, 1000));
    }

    /// `mark_failed` moves the connection into the terminal `Failed` state.
    #[test]
    fn test_connection_failure() {
        let identity = make_peer_identity();
        let mut conn = PeerConnection::outbound(LinkId::new(1), identity, 1000);
        conn.mark_failed();
        assert!(conn.is_failed());
        assert!(!conn.is_in_progress());
        assert!(!conn.is_complete());
    }

    /// The responder entry point rejects outbound connections, and the
    /// initiator entry point rejects inbound ones. The direction check runs
    /// before any parsing, so the zero-filled buffer is never inspected.
    /// `keypair` is passed by value twice, which requires `Keypair: Copy`.
    #[test]
    fn test_wrong_direction_errors() {
        let identity = make_peer_identity();
        let keypair = make_keypair();
        let mut outbound = PeerConnection::outbound(LinkId::new(1), identity, 1000);
        assert!(
            outbound
                .receive_handshake_init(keypair, make_epoch(), &[0u8; 106], 1100)
                .is_err()
        );
        let mut inbound = PeerConnection::inbound(LinkId::new(2), 1000);
        assert!(
            inbound
                .start_handshake(keypair, make_epoch(), 1100)
                .is_err()
        );
    }
}