Skip to main content

slim_datapath/sync/
state.rs

1// Copyright AGNTCY Contributors (https://github.com/agntcy)
2// SPDX-License-Identifier: Apache-2.0
3
4use std::collections::HashMap;
5
/// Tracks the state of connected peers and their connection IDs.
///
/// Two maps are kept in sync: `peers` (peer ID → [`PeerEntry`]) and
/// `conn_to_peer` (connection ID → peer ID). Every mutating method preserves
/// the invariant that the two maps are in one-to-one correspondence: each
/// registered peer has exactly one connection ID, and each registered
/// connection ID belongs to exactly one peer.
#[derive(Debug, Default)]
pub struct PeerState {
    /// Maps peer_id → connection metadata.
    peers: HashMap<String, PeerEntry>,
    /// Maps conn_id → peer_id for reverse lookup (e.g., on connection drop).
    conn_to_peer: HashMap<u64, String>,
}

/// Connection metadata recorded for a single peer.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PeerEntry {
    /// Connection ID; also the key of the reverse (conn → peer) lookup.
    pub conn_id: u64,
    /// Endpoint associated with the connection; stored verbatim, not validated here.
    pub endpoint: String,
    /// Whether we initiated the connection (true) or accepted it (false).
    pub is_outgoing: bool,
}

impl PeerState {
    /// Create an empty peer table.
    pub fn new() -> Self {
        Self::default()
    }

    /// Register a connected peer.
    ///
    /// Returns `true` if the peer was added. Returns `false` and leaves the
    /// state unchanged if `peer_id` is already registered, or if
    /// `entry.conn_id` is already associated with a registered peer.
    pub fn insert(&mut self, peer_id: String, entry: PeerEntry) -> bool {
        // A reused conn_id would overwrite the reverse mapping of the peer that
        // already owns it: that peer could no longer be found by
        // `remove_by_conn`, and a later `remove` of either peer would drop the
        // other's reverse entry. Reject it to keep the maps one-to-one.
        if self.peers.contains_key(&peer_id)
            || self.conn_to_peer.contains_key(&entry.conn_id)
        {
            return false;
        }
        self.conn_to_peer.insert(entry.conn_id, peer_id.clone());
        self.peers.insert(peer_id, entry);
        true
    }

    /// Remove a peer by ID, together with its reverse (conn → peer) mapping.
    ///
    /// Returns the removed entry, or `None` if the peer was not registered.
    pub fn remove(&mut self, peer_id: &str) -> Option<PeerEntry> {
        let entry = self.peers.remove(peer_id)?;
        self.conn_to_peer.remove(&entry.conn_id);
        Some(entry)
    }

    /// Remove a peer by connection ID (e.g., on unexpected disconnect).
    ///
    /// Returns the peer ID and its entry, or `None` if `conn_id` is unknown.
    pub fn remove_by_conn(&mut self, conn_id: u64) -> Option<(String, PeerEntry)> {
        let peer_id = self.conn_to_peer.remove(&conn_id)?;
        let entry = self.peers.remove(&peer_id)?;
        Some((peer_id, entry))
    }

    /// Check if a peer is already connected (by peer ID).
    pub fn contains(&self, peer_id: &str) -> bool {
        self.peers.contains_key(peer_id)
    }

    /// Get the connection ID for a peer, or `None` if it is not registered.
    pub fn conn_id(&self, peer_id: &str) -> Option<u64> {
        self.peers.get(peer_id).map(|e| e.conn_id)
    }

    /// Get the connection IDs of all registered peers.
    ///
    /// The order is unspecified (it follows `HashMap` iteration order).
    pub fn all_conn_ids(&self) -> Vec<u64> {
        self.peers.values().map(|e| e.conn_id).collect()
    }

    /// Look up a peer_id by connection ID.
    pub fn peer_id_for_conn(&self, conn_id: u64) -> Option<&str> {
        self.conn_to_peer.get(&conn_id).map(String::as_str)
    }

    /// Number of connected peers.
    pub fn len(&self) -> usize {
        self.peers.len()
    }

    /// Whether there are no connected peers.
    pub fn is_empty(&self) -> bool {
        self.peers.is_empty()
    }
}
88
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an outgoing entry for `conn_id` with a fixed endpoint.
    fn entry(conn_id: u64) -> PeerEntry {
        PeerEntry {
            conn_id,
            endpoint: "peer-1:8080".to_string(),
            is_outgoing: true,
        }
    }

    #[test]
    fn test_insert_and_lookup() {
        let mut state = PeerState::new();
        assert!(state.insert("peer-1".to_string(), entry(42)));
        assert!(state.contains("peer-1"));
        assert_eq!(state.conn_id("peer-1"), Some(42));
        assert_eq!(state.peer_id_for_conn(42), Some("peer-1"));
        assert_eq!(state.len(), 1);
    }

    #[test]
    fn test_duplicate_insert_rejected() {
        let mut state = PeerState::new();
        assert!(state.insert("peer-1".to_string(), entry(42)));
        assert!(!state.insert("peer-1".to_string(), entry(42)));
        assert_eq!(state.len(), 1);
    }

    #[test]
    fn test_remove_by_id() {
        let mut state = PeerState::new();
        state.insert("peer-1".to_string(), entry(42));
        let removed = state.remove("peer-1").unwrap();
        assert_eq!(removed.conn_id, 42);
        assert!(!state.contains("peer-1"));
        assert_eq!(state.peer_id_for_conn(42), None);
    }

    #[test]
    fn test_remove_by_conn() {
        let mut state = PeerState::new();
        state.insert("peer-1".to_string(), entry(42));
        let (peer_id, removed) = state.remove_by_conn(42).unwrap();
        assert_eq!(peer_id, "peer-1");
        assert_eq!(removed.conn_id, 42);
        assert!(state.is_empty());
        // Both directions of the mapping must be gone.
        assert!(!state.contains("peer-1"));
        assert_eq!(state.peer_id_for_conn(42), None);
    }

    #[test]
    fn test_new_state_is_empty() {
        let state = PeerState::new();
        assert!(state.is_empty());
        assert_eq!(state.len(), 0);
        assert!(state.all_conn_ids().is_empty());
    }

    #[test]
    fn test_unknown_keys_return_none() {
        let mut state = PeerState::new();
        state.insert("peer-1".to_string(), entry(42));
        assert!(state.remove("peer-2").is_none());
        assert!(state.remove_by_conn(7).is_none());
        assert_eq!(state.conn_id("peer-2"), None);
        assert_eq!(state.peer_id_for_conn(7), None);
        // Failed removals must not disturb the registered peer.
        assert_eq!(state.len(), 1);
        assert_eq!(state.peer_id_for_conn(42), Some("peer-1"));
    }

    #[test]
    fn test_all_conn_ids() {
        let mut state = PeerState::new();
        state.insert("peer-1".to_string(), entry(1));
        state.insert("peer-2".to_string(), entry(2));
        let mut ids = state.all_conn_ids();
        // HashMap iteration order is unspecified, so sort before comparing.
        ids.sort_unstable();
        assert_eq!(ids, vec![1, 2]);
    }
}