1use async_trait::async_trait;
4use parking_lot::RwLock;
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use std::sync::Arc;
8use tokio::sync::mpsc;
9
10#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
12pub struct PeerId(pub String);
13
14impl PeerId {
15 pub fn new(id: impl Into<String>) -> Self {
16 Self(id.into())
17 }
18}
19
20impl std::fmt::Display for PeerId {
21 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
22 write!(f, "{}", self.0)
23 }
24}
25
/// Connection status of a [`Peer`] as tracked by a transport.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PeerState {
    /// The peer is not connected.
    Disconnected,
    /// Presumably an in-progress connection attempt (TODO confirm; nothing in
    /// this module sets it).
    Connecting,
    /// The peer is connected.
    Connected,
}
33
34#[derive(Clone, Debug)]
36pub struct Peer {
37 pub id: PeerId,
38 pub name: String,
39 pub state: PeerState,
40}
41
42#[derive(Clone, Debug, Serialize, Deserialize)]
44pub enum Message {
45 Hello {
47 replica_id: String,
48 user_name: String,
49 },
50 SyncRequest { document_id: String, version: u64 },
52 SyncResponse {
54 document_id: String,
55 deltas: Vec<Vec<u8>>,
56 version: u64,
57 },
58 Update {
60 document_id: String,
61 delta: Vec<u8>,
62 version: u64,
63 },
64 Presence {
66 user_id: String,
67 document_id: String,
68 cursor_pos: Option<usize>,
69 },
70 Ack { message_id: u64 },
72 Ping,
74 Pong,
76}
77
/// Errors returned by [`NetworkTransport`] operations.
///
/// Derives `PartialEq`/`Eq` so callers and tests can compare errors directly
/// (e.g. with `assert_eq!`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum NetworkError {
    /// A connection could not be established; carries a description.
    ConnectionFailed(String),
    /// No usable route to a peer; carries the peer id as text.
    PeerNotFound(String),
    /// Handing a message to a peer failed; carries the underlying error text.
    SendFailed(String),
    /// Presumably: the transport or link is no longer connected (TODO confirm).
    Disconnected,
}
86
87impl std::fmt::Display for NetworkError {
88 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89 match self {
90 NetworkError::ConnectionFailed(e) => write!(f, "Connection failed: {}", e),
91 NetworkError::PeerNotFound(id) => write!(f, "Peer not found: {}", id),
92 NetworkError::SendFailed(e) => write!(f, "Send failed: {}", e),
93 NetworkError::Disconnected => write!(f, "Disconnected"),
94 }
95 }
96}
97
98impl std::error::Error for NetworkError {}
99
100#[async_trait]
102pub trait NetworkTransport: Send + Sync + 'static {
103 async fn connect(&self, peer_id: &PeerId) -> Result<(), NetworkError>;
105
106 async fn disconnect(&self, peer_id: &PeerId) -> Result<(), NetworkError>;
108
109 async fn send(&self, peer_id: &PeerId, message: Message) -> Result<(), NetworkError>;
111
112 async fn broadcast(&self, message: Message) -> Result<(), NetworkError>;
114
115 async fn connected_peers(&self) -> Vec<Peer>;
117
118 fn subscribe(&self) -> mpsc::Receiver<(PeerId, Message)>;
120}
121
122type SharedMessageReceiver = Arc<RwLock<Option<mpsc::Receiver<(PeerId, Message)>>>>;
124type SharedOutgoing = Arc<RwLock<HashMap<PeerId, mpsc::Sender<(PeerId, Message)>>>>;
126
127pub struct MemoryTransport {
129 local_id: PeerId,
130 peers: Arc<RwLock<HashMap<PeerId, Peer>>>,
131 message_tx: mpsc::Sender<(PeerId, Message)>,
132 message_rx: SharedMessageReceiver,
133 outgoing: SharedOutgoing,
134}
135
136impl MemoryTransport {
137 pub fn new(local_id: PeerId) -> Self {
138 let (tx, rx) = mpsc::channel(100);
139 Self {
140 local_id,
141 peers: Arc::new(RwLock::new(HashMap::new())),
142 message_tx: tx,
143 message_rx: Arc::new(RwLock::new(Some(rx))),
144 outgoing: Arc::new(RwLock::new(HashMap::new())),
145 }
146 }
147
148 pub fn local_id(&self) -> &PeerId {
149 &self.local_id
150 }
151
152 pub fn connect_to(&self, other: &MemoryTransport) {
154 self.peers.write().insert(
156 other.local_id.clone(),
157 Peer {
158 id: other.local_id.clone(),
159 name: other.local_id.0.clone(),
160 state: PeerState::Connected,
161 },
162 );
163
164 self.outgoing
166 .write()
167 .insert(other.local_id.clone(), other.message_tx.clone());
168
169 other.peers.write().insert(
171 self.local_id.clone(),
172 Peer {
173 id: self.local_id.clone(),
174 name: self.local_id.0.clone(),
175 state: PeerState::Connected,
176 },
177 );
178
179 other
181 .outgoing
182 .write()
183 .insert(self.local_id.clone(), self.message_tx.clone());
184 }
185}
186
187#[async_trait]
188impl NetworkTransport for MemoryTransport {
189 async fn connect(&self, peer_id: &PeerId) -> Result<(), NetworkError> {
190 self.peers.write().insert(
191 peer_id.clone(),
192 Peer {
193 id: peer_id.clone(),
194 name: peer_id.0.clone(),
195 state: PeerState::Connected,
196 },
197 );
198 Ok(())
199 }
200
201 async fn disconnect(&self, peer_id: &PeerId) -> Result<(), NetworkError> {
202 self.peers.write().remove(peer_id);
203 self.outgoing.write().remove(peer_id);
204 Ok(())
205 }
206
207 async fn send(&self, peer_id: &PeerId, message: Message) -> Result<(), NetworkError> {
208 let tx = {
209 let outgoing = self.outgoing.read();
210 outgoing.get(peer_id).cloned()
211 };
212
213 if let Some(tx) = tx {
214 tx.send((self.local_id.clone(), message))
215 .await
216 .map_err(|e| NetworkError::SendFailed(e.to_string()))?;
217 Ok(())
218 } else {
219 Err(NetworkError::PeerNotFound(peer_id.to_string()))
220 }
221 }
222
223 async fn broadcast(&self, message: Message) -> Result<(), NetworkError> {
224 let senders: Vec<_> = {
225 let outgoing = self.outgoing.read();
226 outgoing.values().cloned().collect()
227 };
228
229 for tx in senders {
230 let _ = tx.send((self.local_id.clone(), message.clone())).await;
231 }
232 Ok(())
233 }
234
235 async fn connected_peers(&self) -> Vec<Peer> {
236 self.peers.read().values().cloned().collect()
237 }
238
239 fn subscribe(&self) -> mpsc::Receiver<(PeerId, Message)> {
240 self.message_rx
241 .write()
242 .take()
243 .expect("subscribe can only be called once")
244 }
245}
246
247pub fn create_network(count: usize) -> Vec<MemoryTransport> {
249 let transports: Vec<_> = (0..count)
250 .map(|i| MemoryTransport::new(PeerId::new(format!("peer-{}", i))))
251 .collect();
252
253 for i in 0..count {
255 for j in (i + 1)..count {
256 transports[i].connect_to(&transports[j]);
257 }
258 }
259
260 transports
261}
262
#[cfg(test)]
mod tests {
    use super::*;

    /// `connect_to` registers each side as a peer of the other.
    #[tokio::test]
    async fn test_memory_transport() {
        let transport1 = MemoryTransport::new(PeerId::new("peer-1"));
        let transport2 = MemoryTransport::new(PeerId::new("peer-2"));

        transport1.connect_to(&transport2);

        let peers1 = transport1.connected_peers().await;
        let peers2 = transport2.connected_peers().await;

        assert_eq!(peers1.len(), 1);
        assert_eq!(peers2.len(), 1);
    }

    /// `create_network` produces a full mesh: every transport sees all others.
    #[tokio::test]
    async fn test_network_creation() {
        let network = create_network(3);
        assert_eq!(network.len(), 3);

        for transport in &network {
            let peers = transport.connected_peers().await;
            assert_eq!(peers.len(), 2);
        }
    }

    /// A sent message arrives tagged with the sender's id.
    #[tokio::test]
    async fn test_send_delivers_message_with_sender_id() {
        let a = MemoryTransport::new(PeerId::new("a"));
        let b = MemoryTransport::new(PeerId::new("b"));
        a.connect_to(&b);

        let mut rx = b.subscribe();
        a.send(b.local_id(), Message::Ping).await.unwrap();

        let (from, msg) = rx.recv().await.expect("message should be delivered");
        assert_eq!(from, *a.local_id());
        assert!(matches!(msg, Message::Ping));
    }

    /// Broadcasting reaches every linked peer.
    #[tokio::test]
    async fn test_broadcast_reaches_every_peer() {
        let network = create_network(3);
        let mut receivers: Vec<_> = network[1..].iter().map(|t| t.subscribe()).collect();

        network[0].broadcast(Message::Pong).await.unwrap();

        for rx in &mut receivers {
            let (from, msg) = rx.recv().await.expect("broadcast should reach every peer");
            assert_eq!(from, *network[0].local_id());
            assert!(matches!(msg, Message::Pong));
        }
    }

    /// Sending to a peer without a route fails with `PeerNotFound`.
    #[tokio::test]
    async fn test_send_to_unknown_peer_fails() {
        let a = MemoryTransport::new(PeerId::new("a"));
        let err = a
            .send(&PeerId::new("ghost"), Message::Ping)
            .await
            .unwrap_err();
        assert!(matches!(err, NetworkError::PeerNotFound(id) if id == "ghost"));
    }
}