1use async_trait::async_trait;
4use parking_lot::RwLock;
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use std::sync::Arc;
8use tokio::sync::mpsc;
9
/// String-backed identifier of a peer.
///
/// Derives `Eq` + `Hash` because it is used as the key of the peer table and of the
/// outgoing-route map in `MemoryTransport`; serde derives let it travel inside messages
/// or configuration.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct PeerId(pub String);
13
14impl PeerId {
15 pub fn new(id: impl Into<String>) -> Self {
17 Self(id.into())
18 }
19}
20
21impl std::fmt::Display for PeerId {
22 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23 write!(f, "{}", self.0)
24 }
25}
26
/// Connection state recorded for a [`Peer`].
///
/// The enum is fieldless, so it derives `Copy` (pass by value, no `.clone()` needed)
/// and `Hash` (usable as a set/map key). The in-memory transport in this module only
/// ever assigns `Connected`; the other variants are available to other transports.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum PeerState {
    Disconnected,
    Connecting,
    Connected,
}
34
/// A remote peer as recorded in a transport's peer table.
#[derive(Clone, Debug)]
pub struct Peer {
    /// Identifier used to address the peer (also the peer-table key).
    pub id: PeerId,
    /// Display name; `MemoryTransport` fills it with the id's string form.
    pub name: String,
    /// Connection state recorded for this peer.
    pub state: PeerState,
}
42
/// Message exchanged between peers over a [`NetworkTransport`].
///
/// Serializable via serde. Keep the variant order stable: some serde formats
/// (e.g. bincode) encode enum variants by index, so reordering would break
/// compatibility with already-encoded data.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum Message {
    /// Introduction carrying the sender's replica id and user name.
    /// NOTE(review): presumably the first message on a new link — confirm with callers.
    Hello {
        replica_id: String,
        user_name: String,
    },
    /// Asks for the state of `document_id`; `version` is presumably the requester's
    /// latest known version — TODO confirm.
    SyncRequest { document_id: String, version: u64 },
    /// Reply to a sync request with a list of encoded deltas for `document_id`.
    /// The delta encoding is not defined in this module.
    SyncResponse {
        document_id: String,
        deltas: Vec<Vec<u8>>,
        version: u64,
    },
    /// A single encoded delta for `document_id` at `version` (encoding defined elsewhere).
    Update {
        document_id: String,
        delta: Vec<u8>,
        version: u64,
    },
    /// Presence information for `user_id` in `document_id`.
    /// NOTE(review): the unit of `cursor_pos` (byte vs. char offset) is not defined here.
    Presence {
        user_id: String,
        document_id: String,
        cursor_pos: Option<usize>,
    },
    /// Acknowledges the message identified by `message_id`
    /// (id assignment is not done in this module).
    Ack { message_id: u64 },
    /// Liveness probe; presumably answered with `Pong` — confirm.
    Ping,
    /// Reply to `Ping`.
    Pong,
}
78
/// Errors returned by [`NetworkTransport`] operations.
///
/// Derives `PartialEq`/`Eq` so callers and tests can compare errors directly
/// instead of pattern-matching on every variant.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum NetworkError {
    /// A link to a peer could not be established; carries a description of the cause.
    ConnectionFailed(String),
    /// No route/entry exists for the peer; carries the peer id in string form.
    PeerNotFound(String),
    /// Handing a message to the peer failed; carries the underlying error text.
    SendFailed(String),
    /// Not constructed by `MemoryTransport` in this module; presumably signals that
    /// the transport itself lost its connection — confirm with other implementations.
    Disconnected,
}

impl std::fmt::Display for NetworkError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::ConnectionFailed(reason) => write!(f, "Connection failed: {}", reason),
            Self::PeerNotFound(peer) => write!(f, "Peer not found: {}", peer),
            Self::SendFailed(reason) => write!(f, "Send failed: {}", reason),
            Self::Disconnected => write!(f, "Disconnected"),
        }
    }
}

impl std::error::Error for NetworkError {}
104
/// Abstraction over a peer-to-peer message transport.
///
/// Implementors must be `Send + Sync + 'static`; the async methods are made
/// object-safe via `async_trait`.
#[async_trait]
pub trait NetworkTransport: Send + Sync + 'static {
    /// Establishes (or records) a link to `peer_id`.
    async fn connect(&self, peer_id: &PeerId) -> Result<(), NetworkError>;

    /// Removes the link to `peer_id`.
    async fn disconnect(&self, peer_id: &PeerId) -> Result<(), NetworkError>;

    /// Sends `message` to a single peer.
    ///
    /// Error cases are implementation-defined; `MemoryTransport` returns
    /// `PeerNotFound` for an unrouted peer and `SendFailed` when the peer's inbox is closed.
    async fn send(&self, peer_id: &PeerId, message: Message) -> Result<(), NetworkError>;

    /// Sends `message` to every peer the transport can reach.
    async fn broadcast(&self, message: Message) -> Result<(), NetworkError>;

    /// Returns a snapshot of the peers currently tracked by the transport.
    async fn connected_peers(&self) -> Vec<Peer>;

    /// Returns the receiver of incoming `(sender_id, message)` pairs.
    ///
    /// Not async. `MemoryTransport` hands the receiver out exactly once and panics on a
    /// second call; other implementations may behave differently.
    fn subscribe(&self) -> mpsc::Receiver<(PeerId, Message)>;
}
126
/// This transport's inbox receiver, kept in a lock-protected `Option` so `subscribe`
/// can move it out once (leaving `None` behind).
type SharedMessageReceiver = Arc<RwLock<Option<mpsc::Receiver<(PeerId, Message)>>>>;
/// Senders into other transports' inboxes, keyed by the destination peer id.
type SharedOutgoing = Arc<RwLock<HashMap<PeerId, mpsc::Sender<(PeerId, Message)>>>>;

/// In-process [`NetworkTransport`] that delivers messages over bounded tokio mpsc channels.
///
/// Transports are wired together explicitly with `MemoryTransport::connect_to` or
/// `create_network`; there is no real networking. Presumably intended for tests and
/// simulation — confirm.
pub struct MemoryTransport {
    /// This transport's own id; stamped as the sender on every outgoing message.
    local_id: PeerId,
    /// Peer table keyed by peer id.
    peers: Arc<RwLock<HashMap<PeerId, Peer>>>,
    /// Sender half of this transport's own inbox; clones are handed to linked peers.
    message_tx: mpsc::Sender<(PeerId, Message)>,
    /// Receiver half of the inbox, taken by `subscribe`.
    message_rx: SharedMessageReceiver,
    /// Routes to linked peers' inboxes, used by `send` and `broadcast`.
    outgoing: SharedOutgoing,
}
140
141impl MemoryTransport {
142 pub fn new(local_id: PeerId) -> Self {
144 let (tx, rx) = mpsc::channel(100);
145 Self {
146 local_id,
147 peers: Arc::new(RwLock::new(HashMap::new())),
148 message_tx: tx,
149 message_rx: Arc::new(RwLock::new(Some(rx))),
150 outgoing: Arc::new(RwLock::new(HashMap::new())),
151 }
152 }
153
154 pub fn local_id(&self) -> &PeerId {
156 &self.local_id
157 }
158
159 pub fn connect_to(&self, other: &MemoryTransport) {
161 self.peers.write().insert(
163 other.local_id.clone(),
164 Peer {
165 id: other.local_id.clone(),
166 name: other.local_id.0.clone(),
167 state: PeerState::Connected,
168 },
169 );
170
171 self.outgoing
173 .write()
174 .insert(other.local_id.clone(), other.message_tx.clone());
175
176 other.peers.write().insert(
178 self.local_id.clone(),
179 Peer {
180 id: self.local_id.clone(),
181 name: self.local_id.0.clone(),
182 state: PeerState::Connected,
183 },
184 );
185
186 other
188 .outgoing
189 .write()
190 .insert(self.local_id.clone(), self.message_tx.clone());
191 }
192}
193
194#[async_trait]
195impl NetworkTransport for MemoryTransport {
196 async fn connect(&self, peer_id: &PeerId) -> Result<(), NetworkError> {
197 self.peers.write().insert(
198 peer_id.clone(),
199 Peer {
200 id: peer_id.clone(),
201 name: peer_id.0.clone(),
202 state: PeerState::Connected,
203 },
204 );
205 Ok(())
206 }
207
208 async fn disconnect(&self, peer_id: &PeerId) -> Result<(), NetworkError> {
209 self.peers.write().remove(peer_id);
210 self.outgoing.write().remove(peer_id);
211 Ok(())
212 }
213
214 async fn send(&self, peer_id: &PeerId, message: Message) -> Result<(), NetworkError> {
215 let tx = {
216 let outgoing = self.outgoing.read();
217 outgoing.get(peer_id).cloned()
218 };
219
220 if let Some(tx) = tx {
221 tx.send((self.local_id.clone(), message))
222 .await
223 .map_err(|e| NetworkError::SendFailed(e.to_string()))?;
224 Ok(())
225 } else {
226 Err(NetworkError::PeerNotFound(peer_id.to_string()))
227 }
228 }
229
230 async fn broadcast(&self, message: Message) -> Result<(), NetworkError> {
231 let senders: Vec<_> = {
232 let outgoing = self.outgoing.read();
233 outgoing.values().cloned().collect()
234 };
235
236 for tx in senders {
237 let _ = tx.send((self.local_id.clone(), message.clone())).await;
238 }
239 Ok(())
240 }
241
242 async fn connected_peers(&self) -> Vec<Peer> {
243 self.peers.read().values().cloned().collect()
244 }
245
246 fn subscribe(&self) -> mpsc::Receiver<(PeerId, Message)> {
247 self.message_rx
248 .write()
249 .take()
250 .expect("subscribe can only be called once")
251 }
252}
253
254pub fn create_network(count: usize) -> Vec<MemoryTransport> {
259 let transports: Vec<_> = (0..count)
260 .map(|i| MemoryTransport::new(PeerId::new(format!("peer-{}", i))))
261 .collect();
262
263 for i in 0..count {
265 for j in (i + 1)..count {
266 transports[i].connect_to(&transports[j]);
267 }
268 }
269
270 transports
271}
272
#[cfg(test)]
mod tests {
    use super::*;

    /// Linking two transports makes each list exactly the other one as a connected peer.
    #[tokio::test]
    async fn test_memory_transport() {
        let transport1 = MemoryTransport::new(PeerId::new("peer-1"));
        let transport2 = MemoryTransport::new(PeerId::new("peer-2"));

        transport1.connect_to(&transport2);

        let peers1 = transport1.connected_peers().await;
        let peers2 = transport2.connected_peers().await;

        assert_eq!(peers1.len(), 1);
        assert_eq!(peers2.len(), 1);
        // Each side must see the *other* transport, not itself.
        assert_eq!(&peers1[0].id, transport2.local_id());
        assert_eq!(&peers2[0].id, transport1.local_id());
        assert_eq!(peers1[0].state, PeerState::Connected);
    }

    /// `send` lands in the destination inbox, tagged with the sender's id.
    #[tokio::test]
    async fn test_send_delivers_with_sender_id() {
        let transport1 = MemoryTransport::new(PeerId::new("peer-1"));
        let transport2 = MemoryTransport::new(PeerId::new("peer-2"));
        transport1.connect_to(&transport2);

        let mut inbox2 = transport2.subscribe();
        transport1
            .send(transport2.local_id(), Message::Ping)
            .await
            .expect("peer-2 is routed");

        let (from, message) = inbox2.recv().await.expect("inbox still open");
        assert_eq!(&from, transport1.local_id());
        assert!(matches!(message, Message::Ping));
    }

    /// Sending to a peer without a route reports `PeerNotFound` with that id.
    #[tokio::test]
    async fn test_send_to_unknown_peer_fails() {
        let transport = MemoryTransport::new(PeerId::new("peer-1"));
        let result = transport.send(&PeerId::new("ghost"), Message::Ping).await;
        assert!(matches!(result, Err(NetworkError::PeerNotFound(id)) if id == "ghost"));
    }

    /// `create_network` builds a full mesh in which no transport lists itself.
    #[tokio::test]
    async fn test_network_creation() {
        let network = create_network(3);
        assert_eq!(network.len(), 3);

        for transport in &network {
            let peers = transport.connected_peers().await;
            assert_eq!(peers.len(), 2);
            assert!(peers.iter().all(|p| &p.id != transport.local_id()));
        }
    }

    /// `broadcast` reaches every linked peer.
    #[tokio::test]
    async fn test_broadcast_reaches_all_peers() {
        let network = create_network(3);
        let mut inboxes: Vec<_> = network[1..].iter().map(|t| t.subscribe()).collect();

        network[0]
            .broadcast(Message::Pong)
            .await
            .expect("broadcast is best-effort");

        for inbox in &mut inboxes {
            let (from, message) = inbox.recv().await.expect("inbox still open");
            assert_eq!(&from, network[0].local_id());
            assert!(matches!(message, Message::Pong));
        }
    }
}