Skip to main content

mdcs_sdk/
network.rs

1//! Network transport abstractions for MDCS synchronization.
2
3use async_trait::async_trait;
4use parking_lot::RwLock;
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use std::sync::Arc;
8use tokio::sync::mpsc;
9
10/// Unique identifier for a peer.
11#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
12pub struct PeerId(pub String);
13
14impl PeerId {
15    pub fn new(id: impl Into<String>) -> Self {
16        Self(id.into())
17    }
18}
19
20impl std::fmt::Display for PeerId {
21    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
22        write!(f, "{}", self.0)
23    }
24}
25
/// Connection state recorded for a peer.
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum PeerState {
    /// Not connected to the peer.
    Disconnected,
    /// A connection attempt is in progress.
    Connecting,
    /// Connected to the peer.
    Connected,
}
33
34/// Information about a connected peer.
35#[derive(Clone, Debug)]
36pub struct Peer {
37    pub id: PeerId,
38    pub name: String,
39    pub state: PeerState,
40}
41
42/// Messages exchanged between peers.
43#[derive(Clone, Debug, Serialize, Deserialize)]
44pub enum Message {
45    /// Hello/handshake message.
46    Hello {
47        replica_id: String,
48        user_name: String,
49    },
50    /// Request sync for a document.
51    SyncRequest { document_id: String, version: u64 },
52    /// Response with deltas.
53    SyncResponse {
54        document_id: String,
55        deltas: Vec<Vec<u8>>,
56        version: u64,
57    },
58    /// Incremental update.
59    Update {
60        document_id: String,
61        delta: Vec<u8>,
62        version: u64,
63    },
64    /// Presence update.
65    Presence {
66        user_id: String,
67        document_id: String,
68        cursor_pos: Option<usize>,
69    },
70    /// Acknowledgment.
71    Ack { message_id: u64 },
72    /// Ping for keepalive.
73    Ping,
74    /// Pong response.
75    Pong,
76}
77
/// Errors reported by network transports.
#[derive(Debug, Clone)]
pub enum NetworkError {
    /// A connection could not be established; carries a description.
    ConnectionFailed(String),
    /// The addressed peer is unknown; carries the peer's identifier text.
    PeerNotFound(String),
    /// A message could not be sent; carries a description.
    SendFailed(String),
    /// The transport is disconnected.
    Disconnected,
}

impl std::fmt::Display for NetworkError {
    /// Renders a short human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Disconnected => write!(f, "Disconnected"),
            Self::ConnectionFailed(reason) => write!(f, "Connection failed: {}", reason),
            Self::PeerNotFound(peer) => write!(f, "Peer not found: {}", peer),
            Self::SendFailed(reason) => write!(f, "Send failed: {}", reason),
        }
    }
}

impl std::error::Error for NetworkError {}
99
100/// Abstract network transport trait.
101#[async_trait]
102pub trait NetworkTransport: Send + Sync + 'static {
103    /// Connect to a peer.
104    async fn connect(&self, peer_id: &PeerId) -> Result<(), NetworkError>;
105
106    /// Disconnect from a peer.
107    async fn disconnect(&self, peer_id: &PeerId) -> Result<(), NetworkError>;
108
109    /// Send a message to a specific peer.
110    async fn send(&self, peer_id: &PeerId, message: Message) -> Result<(), NetworkError>;
111
112    /// Broadcast a message to all connected peers.
113    async fn broadcast(&self, message: Message) -> Result<(), NetworkError>;
114
115    /// Get list of connected peers.
116    async fn connected_peers(&self) -> Vec<Peer>;
117
118    /// Subscribe to incoming messages.
119    fn subscribe(&self) -> mpsc::Receiver<(PeerId, Message)>;
120}
121
122/// Type alias for the message receiver shared across threads.
123type SharedMessageReceiver = Arc<RwLock<Option<mpsc::Receiver<(PeerId, Message)>>>>;
124/// Type alias for the outgoing message senders shared across threads.
125type SharedOutgoing = Arc<RwLock<HashMap<PeerId, mpsc::Sender<(PeerId, Message)>>>>;
126
127/// In-memory transport for testing and simulation.
128pub struct MemoryTransport {
129    local_id: PeerId,
130    peers: Arc<RwLock<HashMap<PeerId, Peer>>>,
131    message_tx: mpsc::Sender<(PeerId, Message)>,
132    message_rx: SharedMessageReceiver,
133    outgoing: SharedOutgoing,
134}
135
136impl MemoryTransport {
137    pub fn new(local_id: PeerId) -> Self {
138        let (tx, rx) = mpsc::channel(100);
139        Self {
140            local_id,
141            peers: Arc::new(RwLock::new(HashMap::new())),
142            message_tx: tx,
143            message_rx: Arc::new(RwLock::new(Some(rx))),
144            outgoing: Arc::new(RwLock::new(HashMap::new())),
145        }
146    }
147
148    pub fn local_id(&self) -> &PeerId {
149        &self.local_id
150    }
151
152    /// Connect two memory transports together (for testing).
153    pub fn connect_to(&self, other: &MemoryTransport) {
154        // Add peer to our list
155        self.peers.write().insert(
156            other.local_id.clone(),
157            Peer {
158                id: other.local_id.clone(),
159                name: other.local_id.0.clone(),
160                state: PeerState::Connected,
161            },
162        );
163
164        // Set up channel to send to other
165        self.outgoing
166            .write()
167            .insert(other.local_id.clone(), other.message_tx.clone());
168
169        // Add us to other's peer list
170        other.peers.write().insert(
171            self.local_id.clone(),
172            Peer {
173                id: self.local_id.clone(),
174                name: self.local_id.0.clone(),
175                state: PeerState::Connected,
176            },
177        );
178
179        // Set up channel for other to send to us
180        other
181            .outgoing
182            .write()
183            .insert(self.local_id.clone(), self.message_tx.clone());
184    }
185}
186
187#[async_trait]
188impl NetworkTransport for MemoryTransport {
189    async fn connect(&self, peer_id: &PeerId) -> Result<(), NetworkError> {
190        self.peers.write().insert(
191            peer_id.clone(),
192            Peer {
193                id: peer_id.clone(),
194                name: peer_id.0.clone(),
195                state: PeerState::Connected,
196            },
197        );
198        Ok(())
199    }
200
201    async fn disconnect(&self, peer_id: &PeerId) -> Result<(), NetworkError> {
202        self.peers.write().remove(peer_id);
203        self.outgoing.write().remove(peer_id);
204        Ok(())
205    }
206
207    async fn send(&self, peer_id: &PeerId, message: Message) -> Result<(), NetworkError> {
208        let tx = {
209            let outgoing = self.outgoing.read();
210            outgoing.get(peer_id).cloned()
211        };
212
213        if let Some(tx) = tx {
214            tx.send((self.local_id.clone(), message))
215                .await
216                .map_err(|e| NetworkError::SendFailed(e.to_string()))?;
217            Ok(())
218        } else {
219            Err(NetworkError::PeerNotFound(peer_id.to_string()))
220        }
221    }
222
223    async fn broadcast(&self, message: Message) -> Result<(), NetworkError> {
224        let senders: Vec<_> = {
225            let outgoing = self.outgoing.read();
226            outgoing.values().cloned().collect()
227        };
228
229        for tx in senders {
230            let _ = tx.send((self.local_id.clone(), message.clone())).await;
231        }
232        Ok(())
233    }
234
235    async fn connected_peers(&self) -> Vec<Peer> {
236        self.peers.read().values().cloned().collect()
237    }
238
239    fn subscribe(&self) -> mpsc::Receiver<(PeerId, Message)> {
240        self.message_rx
241            .write()
242            .take()
243            .expect("subscribe can only be called once")
244    }
245}
246
247/// Create a network of connected memory transports for testing.
248pub fn create_network(count: usize) -> Vec<MemoryTransport> {
249    let transports: Vec<_> = (0..count)
250        .map(|i| MemoryTransport::new(PeerId::new(format!("peer-{}", i))))
251        .collect();
252
253    // Connect all peers to each other
254    for i in 0..count {
255        for j in (i + 1)..count {
256            transports[i].connect_to(&transports[j]);
257        }
258    }
259
260    transports
261}
262
#[cfg(test)]
mod tests {
    use super::*;

    /// Linking two transports registers each one in the other's peer list.
    #[tokio::test]
    async fn test_memory_transport() {
        let first = MemoryTransport::new(PeerId::new("peer-1"));
        let second = MemoryTransport::new(PeerId::new("peer-2"));
        first.connect_to(&second);

        let seen_by_first = first.connected_peers().await.len();
        let seen_by_second = second.connected_peers().await.len();

        assert_eq!(seen_by_first, 1);
        assert_eq!(seen_by_second, 1);
    }

    /// A three-node network is fully meshed: every node sees the other two.
    #[tokio::test]
    async fn test_network_creation() {
        let mesh = create_network(3);
        assert_eq!(mesh.len(), 3);

        for node in mesh.iter() {
            let neighbours = node.connected_peers().await;
            assert_eq!(neighbours.len(), 2);
        }
    }
}
292}