Skip to main content

fips_core/peer/
connection.rs

1//! Peer Connection (Handshake Phase)
2//!
3//! Represents an in-progress connection before authentication completes.
4//! PeerConnection tracks the Noise IK handshake state and transitions to
5//! ActivePeer upon successful authentication.
6
7use crate::PeerIdentity;
8use crate::noise::{self, NoiseError, NoiseSession};
9use crate::transport::{LinkDirection, LinkId, LinkStats, TransportAddr, TransportId};
10use crate::utils::index::SessionIndex;
11use secp256k1::Keypair;
12use std::fmt;
13
14/// Handshake protocol state machine.
15///
16/// For Noise IK pattern:
17/// - Initiator: Initial → SentMsg1 → Complete
18/// - Responder: Initial → ReceivedMsg1 → Complete
19#[derive(Clone, Copy, Debug, PartialEq, Eq)]
20pub enum HandshakeState {
21    /// Initial state, ready to start handshake.
22    Initial,
23    /// Initiator: Sent message 1, awaiting message 2.
24    SentMsg1,
25    /// Responder: Received message 1, ready to send message 2.
26    ReceivedMsg1,
27    /// Handshake completed successfully.
28    Complete,
29    /// Handshake failed.
30    Failed,
31}
32
33impl HandshakeState {
34    /// Check if handshake is still in progress.
35    pub fn is_in_progress(&self) -> bool {
36        matches!(
37            self,
38            HandshakeState::Initial | HandshakeState::SentMsg1 | HandshakeState::ReceivedMsg1
39        )
40    }
41
42    /// Check if handshake completed successfully.
43    pub fn is_complete(&self) -> bool {
44        matches!(self, HandshakeState::Complete)
45    }
46
47    /// Check if handshake failed.
48    pub fn is_failed(&self) -> bool {
49        matches!(self, HandshakeState::Failed)
50    }
51}
52
53impl fmt::Display for HandshakeState {
54    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
55        let s = match self {
56            HandshakeState::Initial => "initial",
57            HandshakeState::SentMsg1 => "sent_msg1",
58            HandshakeState::ReceivedMsg1 => "received_msg1",
59            HandshakeState::Complete => "complete",
60            HandshakeState::Failed => "failed",
61        };
62        write!(f, "{}", s)
63    }
64}
65
66/// A connection in the handshake phase, before authentication completes.
67///
68/// For outbound connections, we know the expected peer identity from config.
69/// For inbound connections, we learn the identity during the Noise handshake.
70pub struct PeerConnection {
71    // === Link Reference ===
72    /// The link carrying this connection.
73    link_id: LinkId,
74
75    /// Connection direction (we initiated or they initiated).
76    direction: LinkDirection,
77
78    // === Handshake State ===
79    /// Current handshake state.
80    handshake_state: HandshakeState,
81
82    /// Expected peer identity (known for outbound, learned for inbound).
83    /// Updated after receiving their static key in the handshake.
84    expected_identity: Option<PeerIdentity>,
85
86    /// Noise handshake state (consumes on completion).
87    noise_handshake: Option<noise::HandshakeState>,
88
89    /// Completed Noise session (available after handshake complete).
90    noise_session: Option<NoiseSession>,
91
92    // === Timing ===
93    /// When the connection attempt started (Unix milliseconds).
94    started_at: u64,
95
96    /// When the last handshake message was sent/received.
97    last_activity: u64,
98
99    // === Statistics ===
100    /// Link statistics during handshake.
101    link_stats: LinkStats,
102
103    // === Wire Protocol Index Tracking ===
104    /// Our sender_idx for this handshake (chosen by us).
105    /// For outbound: included in msg1, used as receiver_idx in msg2 echo.
106    /// For inbound: chosen after processing msg1, included in msg2.
107    our_index: Option<SessionIndex>,
108
109    /// Their sender_idx (learned from their messages).
110    /// For outbound: learned from msg2.
111    /// For inbound: learned from msg1.
112    their_index: Option<SessionIndex>,
113
114    /// Transport ID (for index namespace).
115    transport_id: Option<TransportId>,
116
117    /// Current source address (updated on packet receipt).
118    source_addr: Option<TransportAddr>,
119
120    // === Epoch (Restart Detection) ===
121    /// Remote peer's startup epoch (learned from handshake).
122    remote_epoch: Option<[u8; 8]>,
123
124    // === Handshake Resend ===
125    /// Wire-format msg1 bytes for resend (initiator only).
126    handshake_msg1: Option<Vec<u8>>,
127
128    /// Wire-format msg2 bytes for resend (responder only).
129    handshake_msg2: Option<Vec<u8>>,
130
131    /// Number of resends performed so far.
132    resend_count: u32,
133
134    /// When the next resend should fire (Unix ms). 0 = no resend scheduled.
135    next_resend_at_ms: u64,
136}
137
138impl PeerConnection {
139    /// Create a new outbound connection (we are initiating).
140    ///
141    /// For outbound, we know who we're trying to reach from configuration.
142    /// The Noise handshake will be initialized when `start_handshake` is called.
143    pub fn outbound(
144        link_id: LinkId,
145        expected_identity: PeerIdentity,
146        current_time_ms: u64,
147    ) -> Self {
148        Self {
149            link_id,
150            direction: LinkDirection::Outbound,
151            handshake_state: HandshakeState::Initial,
152            expected_identity: Some(expected_identity),
153            noise_handshake: None,
154            noise_session: None,
155            started_at: current_time_ms,
156            last_activity: current_time_ms,
157
158            link_stats: LinkStats::new(),
159            our_index: None,
160            their_index: None,
161            transport_id: None,
162            source_addr: None,
163            remote_epoch: None,
164            handshake_msg1: None,
165            handshake_msg2: None,
166            resend_count: 0,
167            next_resend_at_ms: 0,
168        }
169    }
170
171    /// Create a new inbound connection (they are initiating).
172    ///
173    /// For inbound, we don't know who they are until we decrypt their
174    /// identity from Noise message 1.
175    pub fn inbound(link_id: LinkId, current_time_ms: u64) -> Self {
176        Self {
177            link_id,
178            direction: LinkDirection::Inbound,
179            handshake_state: HandshakeState::Initial,
180            expected_identity: None,
181            noise_handshake: None,
182            noise_session: None,
183            started_at: current_time_ms,
184            last_activity: current_time_ms,
185
186            link_stats: LinkStats::new(),
187            our_index: None,
188            their_index: None,
189            transport_id: None,
190            source_addr: None,
191            remote_epoch: None,
192            handshake_msg1: None,
193            handshake_msg2: None,
194            resend_count: 0,
195            next_resend_at_ms: 0,
196        }
197    }
198
199    /// Create a new inbound connection with transport information.
200    ///
201    /// Used when processing msg1 where we know the transport and source address.
202    pub fn inbound_with_transport(
203        link_id: LinkId,
204        transport_id: TransportId,
205        source_addr: TransportAddr,
206        current_time_ms: u64,
207    ) -> Self {
208        Self {
209            link_id,
210            direction: LinkDirection::Inbound,
211            handshake_state: HandshakeState::Initial,
212            expected_identity: None,
213            noise_handshake: None,
214            noise_session: None,
215            started_at: current_time_ms,
216            last_activity: current_time_ms,
217
218            link_stats: LinkStats::new(),
219            our_index: None,
220            their_index: None,
221            transport_id: Some(transport_id),
222            source_addr: Some(source_addr),
223            remote_epoch: None,
224            handshake_msg1: None,
225            handshake_msg2: None,
226            resend_count: 0,
227            next_resend_at_ms: 0,
228        }
229    }
230
231    // === Accessors ===
232
233    /// Get the link ID.
234    pub fn link_id(&self) -> LinkId {
235        self.link_id
236    }
237
238    /// Get the connection direction.
239    pub fn direction(&self) -> LinkDirection {
240        self.direction
241    }
242
243    /// Get the handshake state.
244    pub fn handshake_state(&self) -> HandshakeState {
245        self.handshake_state
246    }
247
248    /// Get the expected/learned peer identity, if known.
249    pub fn expected_identity(&self) -> Option<&PeerIdentity> {
250        self.expected_identity.as_ref()
251    }
252
253    /// Check if this is an outbound connection.
254    pub fn is_outbound(&self) -> bool {
255        self.direction == LinkDirection::Outbound
256    }
257
258    /// Check if this is an inbound connection.
259    pub fn is_inbound(&self) -> bool {
260        self.direction == LinkDirection::Inbound
261    }
262
263    /// Check if handshake is in progress.
264    pub fn is_in_progress(&self) -> bool {
265        self.handshake_state.is_in_progress()
266    }
267
268    /// Check if handshake completed.
269    pub fn is_complete(&self) -> bool {
270        self.handshake_state.is_complete()
271    }
272
273    /// Check if handshake failed.
274    pub fn is_failed(&self) -> bool {
275        self.handshake_state.is_failed()
276    }
277
278    /// When the connection started.
279    pub fn started_at(&self) -> u64 {
280        self.started_at
281    }
282
283    /// When the last activity occurred.
284    pub fn last_activity(&self) -> u64 {
285        self.last_activity
286    }
287
288    /// Connection duration so far.
289    pub fn duration(&self, current_time_ms: u64) -> u64 {
290        current_time_ms.saturating_sub(self.started_at)
291    }
292
293    /// Time since last activity.
294    pub fn idle_time(&self, current_time_ms: u64) -> u64 {
295        current_time_ms.saturating_sub(self.last_activity)
296    }
297
298    /// Get link statistics.
299    pub fn link_stats(&self) -> &LinkStats {
300        &self.link_stats
301    }
302
303    /// Get mutable link statistics.
304    pub fn link_stats_mut(&mut self) -> &mut LinkStats {
305        &mut self.link_stats
306    }
307
308    // === Index Accessors ===
309
310    /// Get our session index (if set).
311    pub fn our_index(&self) -> Option<SessionIndex> {
312        self.our_index
313    }
314
315    /// Set our session index.
316    pub fn set_our_index(&mut self, index: SessionIndex) {
317        self.our_index = Some(index);
318    }
319
320    /// Get their session index (if known).
321    pub fn their_index(&self) -> Option<SessionIndex> {
322        self.their_index
323    }
324
325    /// Set their session index.
326    pub fn set_their_index(&mut self, index: SessionIndex) {
327        self.their_index = Some(index);
328    }
329
330    /// Get the transport ID (if set).
331    pub fn transport_id(&self) -> Option<TransportId> {
332        self.transport_id
333    }
334
335    /// Set the transport ID.
336    pub fn set_transport_id(&mut self, id: TransportId) {
337        self.transport_id = Some(id);
338    }
339
340    /// Get the source address (if known).
341    pub fn source_addr(&self) -> Option<&TransportAddr> {
342        self.source_addr.as_ref()
343    }
344
345    /// Set the source address.
346    pub fn set_source_addr(&mut self, addr: TransportAddr) {
347        self.source_addr = Some(addr);
348    }
349
350    // === Epoch Accessors ===
351
352    /// Get the remote peer's startup epoch (available after handshake).
353    pub fn remote_epoch(&self) -> Option<[u8; 8]> {
354        self.remote_epoch
355    }
356
357    // === Handshake Resend ===
358
359    /// Store the wire-format msg1 bytes for resend and schedule the first resend.
360    pub fn set_handshake_msg1(&mut self, msg1: Vec<u8>, first_resend_at_ms: u64) {
361        self.handshake_msg1 = Some(msg1);
362        self.resend_count = 0;
363        self.next_resend_at_ms = first_resend_at_ms;
364    }
365
366    /// Store the wire-format msg2 bytes for resend on duplicate msg1.
367    pub fn set_handshake_msg2(&mut self, msg2: Vec<u8>) {
368        self.handshake_msg2 = Some(msg2);
369    }
370
371    /// Get the stored msg1 bytes (if any).
372    pub fn handshake_msg1(&self) -> Option<&[u8]> {
373        self.handshake_msg1.as_deref()
374    }
375
376    /// Get the stored msg2 bytes (if any).
377    pub fn handshake_msg2(&self) -> Option<&[u8]> {
378        self.handshake_msg2.as_deref()
379    }
380
381    /// Number of resends performed.
382    pub fn resend_count(&self) -> u32 {
383        self.resend_count
384    }
385
386    /// When the next resend is scheduled (Unix ms).
387    pub fn next_resend_at_ms(&self) -> u64 {
388        self.next_resend_at_ms
389    }
390
391    /// Record a resend and schedule the next one.
392    pub fn record_resend(&mut self, next_resend_at_ms: u64) {
393        self.resend_count += 1;
394        self.next_resend_at_ms = next_resend_at_ms;
395    }
396
397    // === Noise Handshake Operations ===
398
399    /// Start the handshake as initiator and generate message 1.
400    ///
401    /// For outbound connections only. Returns the handshake message to send.
402    /// The epoch is our startup epoch, encrypted into msg1 for restart detection.
403    pub fn start_handshake(
404        &mut self,
405        our_keypair: Keypair,
406        epoch: [u8; 8],
407        current_time_ms: u64,
408    ) -> Result<Vec<u8>, NoiseError> {
409        if self.direction != LinkDirection::Outbound {
410            return Err(NoiseError::WrongState {
411                expected: "outbound connection".to_string(),
412                got: "inbound connection".to_string(),
413            });
414        }
415
416        if self.handshake_state != HandshakeState::Initial {
417            return Err(NoiseError::WrongState {
418                expected: "initial state".to_string(),
419                got: self.handshake_state.to_string(),
420            });
421        }
422
423        let remote_static = self
424            .expected_identity
425            .as_ref()
426            .expect("outbound must have expected identity")
427            .pubkey_full();
428
429        let mut hs = noise::HandshakeState::new_initiator(our_keypair, remote_static);
430        hs.set_local_epoch(epoch);
431        let msg1 = hs.write_message_1()?;
432
433        self.noise_handshake = Some(hs);
434        self.handshake_state = HandshakeState::SentMsg1;
435        self.last_activity = current_time_ms;
436
437        Ok(msg1)
438    }
439
440    /// Initialize responder and process incoming message 1.
441    ///
442    /// For inbound connections only. Returns the handshake message 2 to send.
443    /// The epoch is our startup epoch, encrypted into msg2 for restart detection.
444    pub fn receive_handshake_init(
445        &mut self,
446        our_keypair: Keypair,
447        epoch: [u8; 8],
448        message: &[u8],
449        current_time_ms: u64,
450    ) -> Result<Vec<u8>, NoiseError> {
451        if self.direction != LinkDirection::Inbound {
452            return Err(NoiseError::WrongState {
453                expected: "inbound connection".to_string(),
454                got: "outbound connection".to_string(),
455            });
456        }
457
458        if self.handshake_state != HandshakeState::Initial {
459            return Err(NoiseError::WrongState {
460                expected: "initial state".to_string(),
461                got: self.handshake_state.to_string(),
462            });
463        }
464
465        let mut hs = noise::HandshakeState::new_responder(our_keypair);
466        hs.set_local_epoch(epoch);
467
468        // Process message 1 (this reveals the initiator's identity and epoch)
469        hs.read_message_1(message)?;
470
471        // Extract the discovered identity
472        let remote_static = *hs
473            .remote_static()
474            .expect("remote static available after msg1");
475        self.expected_identity = Some(PeerIdentity::from_pubkey_full(remote_static));
476
477        // Capture remote epoch from msg1
478        self.remote_epoch = hs.remote_epoch();
479
480        // Generate message 2
481        let msg2 = hs.write_message_2()?;
482
483        // Handshake is complete for responder
484        let session = hs.into_session()?;
485        self.noise_session = Some(session);
486        self.handshake_state = HandshakeState::Complete;
487        self.last_activity = current_time_ms;
488
489        Ok(msg2)
490    }
491
492    /// Complete the handshake by processing message 2.
493    ///
494    /// For outbound connections only (initiator completing handshake).
495    pub fn complete_handshake(
496        &mut self,
497        message: &[u8],
498        current_time_ms: u64,
499    ) -> Result<(), NoiseError> {
500        if self.handshake_state != HandshakeState::SentMsg1 {
501            return Err(NoiseError::WrongState {
502                expected: "sent_msg1 state".to_string(),
503                got: self.handshake_state.to_string(),
504            });
505        }
506
507        let mut hs = self
508            .noise_handshake
509            .take()
510            .expect("noise handshake must exist in SentMsg1 state");
511
512        hs.read_message_2(message)?;
513
514        // Capture remote epoch from msg2
515        self.remote_epoch = hs.remote_epoch();
516
517        let session = hs.into_session()?;
518        self.noise_session = Some(session);
519        self.handshake_state = HandshakeState::Complete;
520        self.last_activity = current_time_ms;
521
522        Ok(())
523    }
524
525    /// Take the completed Noise session.
526    ///
527    /// Returns the NoiseSession for use in ActivePeer. Can only be called
528    /// once after handshake completes.
529    pub fn take_session(&mut self) -> Option<NoiseSession> {
530        if self.handshake_state == HandshakeState::Complete {
531            self.noise_session.take()
532        } else {
533            None
534        }
535    }
536
537    /// Check if we have a completed session ready to take.
538    pub fn has_session(&self) -> bool {
539        self.handshake_state == HandshakeState::Complete && self.noise_session.is_some()
540    }
541
542    // === State Transitions (for manual control if needed) ===
543
544    /// Mark handshake as failed.
545    pub fn mark_failed(&mut self) {
546        self.handshake_state = HandshakeState::Failed;
547        self.noise_handshake = None;
548    }
549
550    /// Update last activity timestamp.
551    pub fn touch(&mut self, current_time_ms: u64) {
552        self.last_activity = current_time_ms;
553    }
554
555    // === Validation ===
556
557    /// Check if the connection has timed out.
558    pub fn is_timed_out(&self, current_time_ms: u64, timeout_ms: u64) -> bool {
559        self.idle_time(current_time_ms) > timeout_ms
560    }
561}
562
563impl fmt::Debug for PeerConnection {
564    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
565        f.debug_struct("PeerConnection")
566            .field("link_id", &self.link_id)
567            .field("direction", &self.direction)
568            .field("handshake_state", &self.handshake_state)
569            .field("expected_identity", &self.expected_identity)
570            .field("has_noise_handshake", &self.noise_handshake.is_some())
571            .field("has_noise_session", &self.noise_session.is_some())
572            .field("our_index", &self.our_index)
573            .field("their_index", &self.their_index)
574            .field("transport_id", &self.transport_id)
575            .field("started_at", &self.started_at)
576            .field("last_activity", &self.last_activity)
577            .finish()
578    }
579}
580
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Identity;
    use rand::Rng;

    fn make_peer_identity() -> PeerIdentity {
        let identity = Identity::generate();
        PeerIdentity::from_pubkey(identity.pubkey())
    }

    fn make_keypair() -> Keypair {
        let identity = Identity::generate();
        identity.keypair()
    }

    fn make_epoch() -> [u8; 8] {
        let mut epoch = [0u8; 8];
        rand::rng().fill_bytes(&mut epoch);
        epoch
    }

    #[test]
    fn test_handshake_state_properties() {
        assert!(HandshakeState::Initial.is_in_progress());
        assert!(HandshakeState::SentMsg1.is_in_progress());
        assert!(HandshakeState::ReceivedMsg1.is_in_progress());
        assert!(!HandshakeState::Complete.is_in_progress());
        assert!(!HandshakeState::Failed.is_in_progress());

        assert!(HandshakeState::Complete.is_complete());
        assert!(!HandshakeState::Failed.is_complete());
        assert!(!HandshakeState::Initial.is_complete());

        assert!(HandshakeState::Failed.is_failed());
        assert!(!HandshakeState::Complete.is_failed());
        assert!(!HandshakeState::SentMsg1.is_failed());
    }

    #[test]
    fn test_handshake_state_display() {
        assert_eq!(HandshakeState::Initial.to_string(), "initial");
        assert_eq!(HandshakeState::SentMsg1.to_string(), "sent_msg1");
        assert_eq!(HandshakeState::ReceivedMsg1.to_string(), "received_msg1");
        assert_eq!(HandshakeState::Complete.to_string(), "complete");
        assert_eq!(HandshakeState::Failed.to_string(), "failed");
    }

    #[test]
    fn test_outbound_connection() {
        let identity = make_peer_identity();
        let conn = PeerConnection::outbound(LinkId::new(1), identity, 1000);

        assert!(conn.is_outbound());
        assert!(!conn.is_inbound());
        assert_eq!(conn.handshake_state(), HandshakeState::Initial);
        assert!(conn.expected_identity().is_some());
        assert_eq!(conn.started_at(), 1000);
    }

    #[test]
    fn test_inbound_connection() {
        let conn = PeerConnection::inbound(LinkId::new(2), 2000);

        assert!(conn.is_inbound());
        assert!(!conn.is_outbound());
        assert_eq!(conn.handshake_state(), HandshakeState::Initial);
        assert!(conn.expected_identity().is_none());
        assert_eq!(conn.started_at(), 2000);
    }

    #[test]
    fn test_full_handshake_flow() {
        // Create identities
        let initiator_identity = Identity::generate();
        let responder_identity = Identity::generate();

        let initiator_keypair = initiator_identity.keypair();
        let responder_keypair = responder_identity.keypair();
        let initiator_epoch = make_epoch();
        let responder_epoch = make_epoch();

        // Use from_pubkey_full to preserve parity for ECDH
        let responder_peer_id = PeerIdentity::from_pubkey_full(responder_identity.pubkey_full());

        // Create connections
        let mut initiator_conn = PeerConnection::outbound(LinkId::new(1), responder_peer_id, 1000);
        let mut responder_conn = PeerConnection::inbound(LinkId::new(2), 1000);

        // Initiator starts handshake
        let msg1 = initiator_conn
            .start_handshake(initiator_keypair, initiator_epoch, 1100)
            .unwrap();
        assert_eq!(initiator_conn.handshake_state(), HandshakeState::SentMsg1);

        // Responder processes msg1 and sends msg2
        let msg2 = responder_conn
            .receive_handshake_init(responder_keypair, responder_epoch, &msg1, 1200)
            .unwrap();
        assert_eq!(responder_conn.handshake_state(), HandshakeState::Complete);

        // Responder learned initiator's identity
        let discovered = responder_conn.expected_identity().unwrap();
        assert_eq!(discovered.pubkey(), initiator_identity.pubkey());

        // Responder learned initiator's epoch
        assert_eq!(responder_conn.remote_epoch(), Some(initiator_epoch));

        // Initiator completes handshake
        initiator_conn.complete_handshake(&msg2, 1300).unwrap();
        assert_eq!(initiator_conn.handshake_state(), HandshakeState::Complete);

        // Initiator learned responder's epoch
        assert_eq!(initiator_conn.remote_epoch(), Some(responder_epoch));

        // Both have sessions
        assert!(initiator_conn.has_session());
        assert!(responder_conn.has_session());

        // Take and verify sessions work
        let mut init_session = initiator_conn.take_session().unwrap();
        let mut resp_session = responder_conn.take_session().unwrap();

        // A session can only be taken once.
        assert!(initiator_conn.take_session().is_none());
        assert!(!initiator_conn.has_session());

        // Encrypt/decrypt test
        let plaintext = b"test message";
        let ciphertext = init_session.encrypt(plaintext).unwrap();
        let decrypted = resp_session.decrypt(&ciphertext).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_connection_timing() {
        let identity = make_peer_identity();
        let conn = PeerConnection::outbound(LinkId::new(1), identity, 1000);

        assert_eq!(conn.duration(1500), 500);
        assert_eq!(conn.idle_time(1500), 500);
        assert!(!conn.is_timed_out(1500, 1000));
        assert!(conn.is_timed_out(2500, 1000));
    }

    #[test]
    fn test_touch_resets_idle_time() {
        let mut conn = PeerConnection::inbound(LinkId::new(3), 1000);
        conn.touch(1800);

        assert_eq!(conn.last_activity(), 1800);
        assert_eq!(conn.idle_time(2000), 200);
        assert_eq!(conn.duration(2000), 1000);
        assert!(!conn.is_timed_out(2000, 500));
        // A clock that goes backwards must not underflow.
        assert_eq!(conn.idle_time(500), 0);
        assert_eq!(conn.duration(500), 0);
    }

    #[test]
    fn test_handshake_resend_tracking() {
        let mut conn = PeerConnection::outbound(LinkId::new(1), make_peer_identity(), 1000);
        assert!(conn.handshake_msg1().is_none());
        assert!(conn.handshake_msg2().is_none());
        assert_eq!(conn.resend_count(), 0);
        assert_eq!(conn.next_resend_at_ms(), 0);

        conn.set_handshake_msg1(vec![1, 2, 3], 1500);
        assert_eq!(conn.handshake_msg1(), Some(&[1u8, 2, 3][..]));
        assert_eq!(conn.resend_count(), 0);
        assert_eq!(conn.next_resend_at_ms(), 1500);

        conn.record_resend(2500);
        assert_eq!(conn.resend_count(), 1);
        assert_eq!(conn.next_resend_at_ms(), 2500);

        // Storing a fresh msg1 resets the resend schedule.
        conn.set_handshake_msg1(vec![4], 3000);
        assert_eq!(conn.resend_count(), 0);
        assert_eq!(conn.next_resend_at_ms(), 3000);

        conn.set_handshake_msg2(vec![9, 9]);
        assert_eq!(conn.handshake_msg2(), Some(&[9u8, 9][..]));
    }

    #[test]
    fn test_connection_failure() {
        let identity = make_peer_identity();
        let mut conn = PeerConnection::outbound(LinkId::new(1), identity, 1000);

        conn.mark_failed();
        assert!(conn.is_failed());
        assert!(!conn.is_in_progress());
        assert!(!conn.is_complete());
        assert!(!conn.has_session());
        assert!(conn.take_session().is_none());
    }

    #[test]
    fn test_wrong_direction_errors() {
        let identity = make_peer_identity();
        let keypair = make_keypair();

        // Outbound can't receive_handshake_init
        let mut outbound = PeerConnection::outbound(LinkId::new(1), identity, 1000);
        assert!(
            outbound
                .receive_handshake_init(keypair, make_epoch(), &[0u8; 106], 1100)
                .is_err()
        );

        // Inbound can't start_handshake
        let mut inbound = PeerConnection::inbound(LinkId::new(2), 1000);
        assert!(
            inbound
                .start_handshake(keypair, make_epoch(), 1100)
                .is_err()
        );
    }

    #[test]
    fn test_wrong_state_errors() {
        // start_handshake may only be called once.
        let responder = Identity::generate();
        let responder_peer_id = PeerIdentity::from_pubkey_full(responder.pubkey_full());
        let mut conn = PeerConnection::outbound(LinkId::new(1), responder_peer_id, 1000);
        conn.start_handshake(make_keypair(), make_epoch(), 1100)
            .unwrap();
        assert!(
            conn.start_handshake(make_keypair(), make_epoch(), 1200)
                .is_err()
        );
        assert_eq!(conn.handshake_state(), HandshakeState::SentMsg1);
        assert_eq!(conn.last_activity(), 1100);

        // complete_handshake before start_handshake is rejected without changing state.
        let mut fresh = PeerConnection::outbound(LinkId::new(2), make_peer_identity(), 1000);
        assert!(fresh.complete_handshake(&[0u8; 64], 1100).is_err());
        assert_eq!(fresh.handshake_state(), HandshakeState::Initial);
        assert!(!fresh.has_session());
        assert!(fresh.take_session().is_none());
    }
}