1use crate::PeerIdentity;
8use crate::noise::{self, NoiseError, NoiseSession};
9use crate::transport::{LinkDirection, LinkId, LinkStats, TransportAddr, TransportId};
10use crate::utils::index::SessionIndex;
11use secp256k1::Keypair;
12use std::fmt;
13
/// Phase of the Noise handshake for a single [`PeerConnection`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HandshakeState {
    /// Nothing has been sent or received yet.
    Initial,
    /// Initiator: msg1 has been written and we are waiting for msg2.
    SentMsg1,
    /// Responder has processed msg1. (The responder path in this module goes
    /// straight from `Initial` to `Complete`, so it never sets this itself.)
    ReceivedMsg1,
    /// Handshake finished; a transport session is available.
    Complete,
    /// Handshake was abandoned or marked failed.
    Failed,
}
32
33impl HandshakeState {
34 pub fn is_in_progress(&self) -> bool {
36 matches!(
37 self,
38 HandshakeState::Initial | HandshakeState::SentMsg1 | HandshakeState::ReceivedMsg1
39 )
40 }
41
42 pub fn is_complete(&self) -> bool {
44 matches!(self, HandshakeState::Complete)
45 }
46
47 pub fn is_failed(&self) -> bool {
49 matches!(self, HandshakeState::Failed)
50 }
51}
52
53impl fmt::Display for HandshakeState {
54 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
55 let s = match self {
56 HandshakeState::Initial => "initial",
57 HandshakeState::SentMsg1 => "sent_msg1",
58 HandshakeState::ReceivedMsg1 => "received_msg1",
59 HandshakeState::Complete => "complete",
60 HandshakeState::Failed => "failed",
61 };
62 write!(f, "{}", s)
63 }
64}
65
66pub struct PeerConnection {
71 link_id: LinkId,
74
75 direction: LinkDirection,
77
78 handshake_state: HandshakeState,
81
82 expected_identity: Option<PeerIdentity>,
85
86 noise_handshake: Option<noise::HandshakeState>,
88
89 noise_session: Option<NoiseSession>,
91
92 started_at: u64,
95
96 last_activity: u64,
98
99 link_stats: LinkStats,
102
103 our_index: Option<SessionIndex>,
108
109 their_index: Option<SessionIndex>,
113
114 transport_id: Option<TransportId>,
116
117 source_addr: Option<TransportAddr>,
119
120 remote_epoch: Option<[u8; 8]>,
123
124 handshake_msg1: Option<Vec<u8>>,
127
128 handshake_msg2: Option<Vec<u8>>,
130
131 resend_count: u32,
133
134 next_resend_at_ms: u64,
136}
137
138impl PeerConnection {
139 pub fn outbound(
144 link_id: LinkId,
145 expected_identity: PeerIdentity,
146 current_time_ms: u64,
147 ) -> Self {
148 Self {
149 link_id,
150 direction: LinkDirection::Outbound,
151 handshake_state: HandshakeState::Initial,
152 expected_identity: Some(expected_identity),
153 noise_handshake: None,
154 noise_session: None,
155 started_at: current_time_ms,
156 last_activity: current_time_ms,
157
158 link_stats: LinkStats::new(),
159 our_index: None,
160 their_index: None,
161 transport_id: None,
162 source_addr: None,
163 remote_epoch: None,
164 handshake_msg1: None,
165 handshake_msg2: None,
166 resend_count: 0,
167 next_resend_at_ms: 0,
168 }
169 }
170
171 pub fn inbound(link_id: LinkId, current_time_ms: u64) -> Self {
176 Self {
177 link_id,
178 direction: LinkDirection::Inbound,
179 handshake_state: HandshakeState::Initial,
180 expected_identity: None,
181 noise_handshake: None,
182 noise_session: None,
183 started_at: current_time_ms,
184 last_activity: current_time_ms,
185
186 link_stats: LinkStats::new(),
187 our_index: None,
188 their_index: None,
189 transport_id: None,
190 source_addr: None,
191 remote_epoch: None,
192 handshake_msg1: None,
193 handshake_msg2: None,
194 resend_count: 0,
195 next_resend_at_ms: 0,
196 }
197 }
198
199 pub fn inbound_with_transport(
203 link_id: LinkId,
204 transport_id: TransportId,
205 source_addr: TransportAddr,
206 current_time_ms: u64,
207 ) -> Self {
208 Self {
209 link_id,
210 direction: LinkDirection::Inbound,
211 handshake_state: HandshakeState::Initial,
212 expected_identity: None,
213 noise_handshake: None,
214 noise_session: None,
215 started_at: current_time_ms,
216 last_activity: current_time_ms,
217
218 link_stats: LinkStats::new(),
219 our_index: None,
220 their_index: None,
221 transport_id: Some(transport_id),
222 source_addr: Some(source_addr),
223 remote_epoch: None,
224 handshake_msg1: None,
225 handshake_msg2: None,
226 resend_count: 0,
227 next_resend_at_ms: 0,
228 }
229 }
230
231 pub fn link_id(&self) -> LinkId {
235 self.link_id
236 }
237
238 pub fn direction(&self) -> LinkDirection {
240 self.direction
241 }
242
243 pub fn handshake_state(&self) -> HandshakeState {
245 self.handshake_state
246 }
247
248 pub fn expected_identity(&self) -> Option<&PeerIdentity> {
250 self.expected_identity.as_ref()
251 }
252
253 pub fn is_outbound(&self) -> bool {
255 self.direction == LinkDirection::Outbound
256 }
257
258 pub fn is_inbound(&self) -> bool {
260 self.direction == LinkDirection::Inbound
261 }
262
263 pub fn is_in_progress(&self) -> bool {
265 self.handshake_state.is_in_progress()
266 }
267
268 pub fn is_complete(&self) -> bool {
270 self.handshake_state.is_complete()
271 }
272
273 pub fn is_failed(&self) -> bool {
275 self.handshake_state.is_failed()
276 }
277
278 pub fn started_at(&self) -> u64 {
280 self.started_at
281 }
282
283 pub fn last_activity(&self) -> u64 {
285 self.last_activity
286 }
287
288 pub fn duration(&self, current_time_ms: u64) -> u64 {
290 current_time_ms.saturating_sub(self.started_at)
291 }
292
293 pub fn idle_time(&self, current_time_ms: u64) -> u64 {
295 current_time_ms.saturating_sub(self.last_activity)
296 }
297
298 pub fn link_stats(&self) -> &LinkStats {
300 &self.link_stats
301 }
302
303 pub fn link_stats_mut(&mut self) -> &mut LinkStats {
305 &mut self.link_stats
306 }
307
308 pub fn our_index(&self) -> Option<SessionIndex> {
312 self.our_index
313 }
314
315 pub fn set_our_index(&mut self, index: SessionIndex) {
317 self.our_index = Some(index);
318 }
319
320 pub fn their_index(&self) -> Option<SessionIndex> {
322 self.their_index
323 }
324
325 pub fn set_their_index(&mut self, index: SessionIndex) {
327 self.their_index = Some(index);
328 }
329
330 pub fn transport_id(&self) -> Option<TransportId> {
332 self.transport_id
333 }
334
335 pub fn set_transport_id(&mut self, id: TransportId) {
337 self.transport_id = Some(id);
338 }
339
340 pub fn source_addr(&self) -> Option<&TransportAddr> {
342 self.source_addr.as_ref()
343 }
344
345 pub fn set_source_addr(&mut self, addr: TransportAddr) {
347 self.source_addr = Some(addr);
348 }
349
350 pub fn remote_epoch(&self) -> Option<[u8; 8]> {
354 self.remote_epoch
355 }
356
357 pub fn set_handshake_msg1(&mut self, msg1: Vec<u8>, first_resend_at_ms: u64) {
361 self.handshake_msg1 = Some(msg1);
362 self.resend_count = 0;
363 self.next_resend_at_ms = first_resend_at_ms;
364 }
365
366 pub fn set_handshake_msg2(&mut self, msg2: Vec<u8>) {
368 self.handshake_msg2 = Some(msg2);
369 }
370
371 pub fn handshake_msg1(&self) -> Option<&[u8]> {
373 self.handshake_msg1.as_deref()
374 }
375
376 pub fn handshake_msg2(&self) -> Option<&[u8]> {
378 self.handshake_msg2.as_deref()
379 }
380
381 pub fn resend_count(&self) -> u32 {
383 self.resend_count
384 }
385
386 pub fn next_resend_at_ms(&self) -> u64 {
388 self.next_resend_at_ms
389 }
390
391 pub fn record_resend(&mut self, next_resend_at_ms: u64) {
393 self.resend_count += 1;
394 self.next_resend_at_ms = next_resend_at_ms;
395 }
396
397 pub fn start_handshake(
404 &mut self,
405 our_keypair: Keypair,
406 epoch: [u8; 8],
407 current_time_ms: u64,
408 ) -> Result<Vec<u8>, NoiseError> {
409 if self.direction != LinkDirection::Outbound {
410 return Err(NoiseError::WrongState {
411 expected: "outbound connection".to_string(),
412 got: "inbound connection".to_string(),
413 });
414 }
415
416 if self.handshake_state != HandshakeState::Initial {
417 return Err(NoiseError::WrongState {
418 expected: "initial state".to_string(),
419 got: self.handshake_state.to_string(),
420 });
421 }
422
423 let remote_static = self
424 .expected_identity
425 .as_ref()
426 .expect("outbound must have expected identity")
427 .pubkey_full();
428
429 let mut hs = noise::HandshakeState::new_initiator(our_keypair, remote_static);
430 hs.set_local_epoch(epoch);
431 let msg1 = hs.write_message_1()?;
432
433 self.noise_handshake = Some(hs);
434 self.handshake_state = HandshakeState::SentMsg1;
435 self.last_activity = current_time_ms;
436
437 Ok(msg1)
438 }
439
440 pub fn receive_handshake_init(
445 &mut self,
446 our_keypair: Keypair,
447 epoch: [u8; 8],
448 message: &[u8],
449 current_time_ms: u64,
450 ) -> Result<Vec<u8>, NoiseError> {
451 if self.direction != LinkDirection::Inbound {
452 return Err(NoiseError::WrongState {
453 expected: "inbound connection".to_string(),
454 got: "outbound connection".to_string(),
455 });
456 }
457
458 if self.handshake_state != HandshakeState::Initial {
459 return Err(NoiseError::WrongState {
460 expected: "initial state".to_string(),
461 got: self.handshake_state.to_string(),
462 });
463 }
464
465 let mut hs = noise::HandshakeState::new_responder(our_keypair);
466 hs.set_local_epoch(epoch);
467
468 hs.read_message_1(message)?;
470
471 let remote_static = *hs
473 .remote_static()
474 .expect("remote static available after msg1");
475 self.expected_identity = Some(PeerIdentity::from_pubkey_full(remote_static));
476
477 self.remote_epoch = hs.remote_epoch();
479
480 let msg2 = hs.write_message_2()?;
482
483 let session = hs.into_session()?;
485 self.noise_session = Some(session);
486 self.handshake_state = HandshakeState::Complete;
487 self.last_activity = current_time_ms;
488
489 Ok(msg2)
490 }
491
492 pub fn complete_handshake(
496 &mut self,
497 message: &[u8],
498 current_time_ms: u64,
499 ) -> Result<(), NoiseError> {
500 if self.handshake_state != HandshakeState::SentMsg1 {
501 return Err(NoiseError::WrongState {
502 expected: "sent_msg1 state".to_string(),
503 got: self.handshake_state.to_string(),
504 });
505 }
506
507 let mut hs = self
508 .noise_handshake
509 .take()
510 .expect("noise handshake must exist in SentMsg1 state");
511
512 hs.read_message_2(message)?;
513
514 self.remote_epoch = hs.remote_epoch();
516
517 let session = hs.into_session()?;
518 self.noise_session = Some(session);
519 self.handshake_state = HandshakeState::Complete;
520 self.last_activity = current_time_ms;
521
522 Ok(())
523 }
524
525 pub fn take_session(&mut self) -> Option<NoiseSession> {
530 if self.handshake_state == HandshakeState::Complete {
531 self.noise_session.take()
532 } else {
533 None
534 }
535 }
536
537 pub fn has_session(&self) -> bool {
539 self.handshake_state == HandshakeState::Complete && self.noise_session.is_some()
540 }
541
542 pub fn mark_failed(&mut self) {
546 self.handshake_state = HandshakeState::Failed;
547 self.noise_handshake = None;
548 }
549
550 pub fn touch(&mut self, current_time_ms: u64) {
552 self.last_activity = current_time_ms;
553 }
554
555 pub fn is_timed_out(&self, current_time_ms: u64, timeout_ms: u64) -> bool {
559 self.idle_time(current_time_ms) > timeout_ms
560 }
561}
562
563impl fmt::Debug for PeerConnection {
564 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
565 f.debug_struct("PeerConnection")
566 .field("link_id", &self.link_id)
567 .field("direction", &self.direction)
568 .field("handshake_state", &self.handshake_state)
569 .field("expected_identity", &self.expected_identity)
570 .field("has_noise_handshake", &self.noise_handshake.is_some())
571 .field("has_noise_session", &self.noise_session.is_some())
572 .field("our_index", &self.our_index)
573 .field("their_index", &self.their_index)
574 .field("transport_id", &self.transport_id)
575 .field("started_at", &self.started_at)
576 .field("last_activity", &self.last_activity)
577 .finish()
578 }
579}
580
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Identity;
    use rand::Rng;

    /// Random peer identity built from a freshly generated identity's pubkey.
    fn make_peer_identity() -> PeerIdentity {
        let generated = Identity::generate();
        PeerIdentity::from_pubkey(generated.pubkey())
    }

    /// Keypair of a freshly generated identity.
    fn make_keypair() -> Keypair {
        let generated = Identity::generate();
        generated.keypair()
    }

    /// Random 8-byte epoch.
    fn make_epoch() -> [u8; 8] {
        let mut bytes = [0u8; 8];
        rand::rng().fill_bytes(&mut bytes);
        bytes
    }

    #[test]
    fn test_handshake_state_properties() {
        for pending in [
            HandshakeState::Initial,
            HandshakeState::SentMsg1,
            HandshakeState::ReceivedMsg1,
        ] {
            assert!(pending.is_in_progress());
        }
        for terminal in [HandshakeState::Complete, HandshakeState::Failed] {
            assert!(!terminal.is_in_progress());
        }

        assert!(HandshakeState::Complete.is_complete());
        assert!(HandshakeState::Failed.is_failed());
    }

    #[test]
    fn test_outbound_connection() {
        let conn = PeerConnection::outbound(LinkId::new(1), make_peer_identity(), 1000);

        assert!(conn.is_outbound());
        assert!(!conn.is_inbound());
        assert_eq!(conn.handshake_state(), HandshakeState::Initial);
        assert!(conn.expected_identity().is_some());
        assert_eq!(conn.started_at(), 1000);
    }

    #[test]
    fn test_inbound_connection() {
        let conn = PeerConnection::inbound(LinkId::new(2), 2000);

        assert!(conn.is_inbound());
        assert!(!conn.is_outbound());
        assert_eq!(conn.handshake_state(), HandshakeState::Initial);
        assert!(conn.expected_identity().is_none());
        assert_eq!(conn.started_at(), 2000);
    }

    #[test]
    fn test_full_handshake_flow() {
        let alice = Identity::generate();
        let bob = Identity::generate();

        let alice_keys = alice.keypair();
        let bob_keys = bob.keypair();
        let alice_epoch = make_epoch();
        let bob_epoch = make_epoch();

        // Alice dials Bob, so she must know his full public key up front.
        let bob_as_peer = PeerIdentity::from_pubkey_full(bob.pubkey_full());
        let mut alice_conn = PeerConnection::outbound(LinkId::new(1), bob_as_peer, 1000);
        let mut bob_conn = PeerConnection::inbound(LinkId::new(2), 1000);

        // Round trip 1: Alice -> Bob.
        let first_msg = alice_conn
            .start_handshake(alice_keys, alice_epoch, 1100)
            .unwrap();
        assert_eq!(alice_conn.handshake_state(), HandshakeState::SentMsg1);

        // Bob answers and is done immediately.
        let second_msg = bob_conn
            .receive_handshake_init(bob_keys, bob_epoch, &first_msg, 1200)
            .unwrap();
        assert_eq!(bob_conn.handshake_state(), HandshakeState::Complete);

        // Bob learned who Alice is, and her epoch, from msg1.
        let learned = bob_conn.expected_identity().unwrap();
        assert_eq!(learned.pubkey(), alice.pubkey());
        assert_eq!(bob_conn.remote_epoch(), Some(alice_epoch));

        // Round trip 2: Bob -> Alice.
        alice_conn.complete_handshake(&second_msg, 1300).unwrap();
        assert_eq!(alice_conn.handshake_state(), HandshakeState::Complete);
        assert_eq!(alice_conn.remote_epoch(), Some(bob_epoch));

        assert!(alice_conn.has_session());
        assert!(bob_conn.has_session());

        // The derived sessions must interoperate.
        let mut alice_session = alice_conn.take_session().unwrap();
        let mut bob_session = bob_conn.take_session().unwrap();

        let payload = b"test message";
        let sealed = alice_session.encrypt(payload).unwrap();
        let opened = bob_session.decrypt(&sealed).unwrap();
        assert_eq!(opened, payload);
    }

    #[test]
    fn test_connection_timing() {
        let conn = PeerConnection::outbound(LinkId::new(1), make_peer_identity(), 1000);

        assert_eq!(conn.duration(1500), 500);
        assert_eq!(conn.idle_time(1500), 500);
        assert!(!conn.is_timed_out(1500, 1000));
        assert!(conn.is_timed_out(2500, 1000));
    }

    #[test]
    fn test_connection_failure() {
        let mut conn = PeerConnection::outbound(LinkId::new(1), make_peer_identity(), 1000);

        conn.mark_failed();

        assert!(conn.is_failed());
        assert!(!conn.is_in_progress());
        assert!(!conn.is_complete());
    }

    #[test]
    fn test_wrong_direction_errors() {
        let peer = make_peer_identity();
        let keys = make_keypair();

        // An outbound connection must refuse to act as responder.
        let mut dialer = PeerConnection::outbound(LinkId::new(1), peer, 1000);
        let as_responder = dialer.receive_handshake_init(keys, make_epoch(), &[0u8; 106], 1100);
        assert!(as_responder.is_err());

        // An inbound connection must refuse to act as initiator.
        let mut listener = PeerConnection::inbound(LinkId::new(2), 1000);
        let as_initiator = listener.start_handshake(keys, make_epoch(), 1100);
        assert!(as_initiator.is_err());
    }
}