// mdcs_sdk/network.rs

//! Network transport abstractions for MDCS synchronization.

3use async_trait::async_trait;
4use parking_lot::RwLock;
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use std::sync::Arc;
8use tokio::sync::mpsc;
9
10/// Unique identifier for a peer.
11#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
12pub struct PeerId(pub String);
13
14impl PeerId {
15    /// Construct a new peer identifier from any string-like input.
16    pub fn new(id: impl Into<String>) -> Self {
17        Self(id.into())
18    }
19}
20
21impl std::fmt::Display for PeerId {
22    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23        write!(f, "{}", self.0)
24    }
25}
26
/// Peer connection state.
///
/// The enum carries no data, so it derives `Copy` (values can be passed
/// around without `.clone()`) and `Hash` (values can be used in hashed
/// collections) in addition to the comparison traits.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum PeerState {
    /// The peer is not connected.
    Disconnected,
    /// A connection to the peer is being established.
    Connecting,
    /// The peer is connected. `MemoryTransport` in this file only ever
    /// assigns this variant.
    Connected,
}
34
35/// Information about a connected peer.
36#[derive(Clone, Debug)]
37pub struct Peer {
38    pub id: PeerId,
39    pub name: String,
40    pub state: PeerState,
41}
42
43/// Messages exchanged between peers.
44#[derive(Clone, Debug, Serialize, Deserialize)]
45pub enum Message {
46    /// Hello/handshake message.
47    Hello {
48        replica_id: String,
49        user_name: String,
50    },
51    /// Request sync for a document.
52    SyncRequest { document_id: String, version: u64 },
53    /// Response with deltas.
54    SyncResponse {
55        document_id: String,
56        deltas: Vec<Vec<u8>>,
57        version: u64,
58    },
59    /// Incremental update.
60    Update {
61        document_id: String,
62        delta: Vec<u8>,
63        version: u64,
64    },
65    /// Presence update.
66    Presence {
67        user_id: String,
68        document_id: String,
69        cursor_pos: Option<usize>,
70    },
71    /// Acknowledgment.
72    Ack { message_id: u64 },
73    /// Ping for keepalive.
74    Ping,
75    /// Pong response.
76    Pong,
77}
78
/// Network error type.
///
/// Derives `PartialEq`/`Eq` so callers and tests can compare errors directly
/// (e.g. `assert_eq!(result, Err(NetworkError::Disconnected))`).
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum NetworkError {
    /// Failed to establish a connection; the string carries the reason.
    ConnectionFailed(String),
    /// Target peer is not known by this transport; the string is the peer id.
    PeerNotFound(String),
    /// Message send failed; the string carries the underlying error text.
    SendFailed(String),
    /// Transport is disconnected.
    Disconnected,
}
91
92impl std::fmt::Display for NetworkError {
93    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
94        match self {
95            NetworkError::ConnectionFailed(e) => write!(f, "Connection failed: {}", e),
96            NetworkError::PeerNotFound(id) => write!(f, "Peer not found: {}", id),
97            NetworkError::SendFailed(e) => write!(f, "Send failed: {}", e),
98            NetworkError::Disconnected => write!(f, "Disconnected"),
99        }
100    }
101}
102
103impl std::error::Error for NetworkError {}
104
105/// Abstract network transport trait.
106#[async_trait]
107pub trait NetworkTransport: Send + Sync + 'static {
108    /// Connect to a peer.
109    async fn connect(&self, peer_id: &PeerId) -> Result<(), NetworkError>;
110
111    /// Disconnect from a peer.
112    async fn disconnect(&self, peer_id: &PeerId) -> Result<(), NetworkError>;
113
114    /// Send a message to a specific peer.
115    async fn send(&self, peer_id: &PeerId, message: Message) -> Result<(), NetworkError>;
116
117    /// Broadcast a message to all connected peers.
118    async fn broadcast(&self, message: Message) -> Result<(), NetworkError>;
119
120    /// Get list of connected peers.
121    async fn connected_peers(&self) -> Vec<Peer>;
122
123    /// Subscribe to incoming messages.
124    fn subscribe(&self) -> mpsc::Receiver<(PeerId, Message)>;
125}
126
127/// Type alias for the message receiver shared across threads.
128type SharedMessageReceiver = Arc<RwLock<Option<mpsc::Receiver<(PeerId, Message)>>>>;
129/// Type alias for the outgoing message senders shared across threads.
130type SharedOutgoing = Arc<RwLock<HashMap<PeerId, mpsc::Sender<(PeerId, Message)>>>>;
131
132/// In-memory transport for testing and simulation.
133pub struct MemoryTransport {
134    local_id: PeerId,
135    peers: Arc<RwLock<HashMap<PeerId, Peer>>>,
136    message_tx: mpsc::Sender<(PeerId, Message)>,
137    message_rx: SharedMessageReceiver,
138    outgoing: SharedOutgoing,
139}
140
141impl MemoryTransport {
142    /// Create a new in-memory transport instance for a local peer.
143    pub fn new(local_id: PeerId) -> Self {
144        let (tx, rx) = mpsc::channel(100);
145        Self {
146            local_id,
147            peers: Arc::new(RwLock::new(HashMap::new())),
148            message_tx: tx,
149            message_rx: Arc::new(RwLock::new(Some(rx))),
150            outgoing: Arc::new(RwLock::new(HashMap::new())),
151        }
152    }
153
154    /// Return the local peer ID associated with this transport.
155    pub fn local_id(&self) -> &PeerId {
156        &self.local_id
157    }
158
159    /// Connect two memory transports together (for testing).
160    pub fn connect_to(&self, other: &MemoryTransport) {
161        // Add peer to our list
162        self.peers.write().insert(
163            other.local_id.clone(),
164            Peer {
165                id: other.local_id.clone(),
166                name: other.local_id.0.clone(),
167                state: PeerState::Connected,
168            },
169        );
170
171        // Set up channel to send to other
172        self.outgoing
173            .write()
174            .insert(other.local_id.clone(), other.message_tx.clone());
175
176        // Add us to other's peer list
177        other.peers.write().insert(
178            self.local_id.clone(),
179            Peer {
180                id: self.local_id.clone(),
181                name: self.local_id.0.clone(),
182                state: PeerState::Connected,
183            },
184        );
185
186        // Set up channel for other to send to us
187        other
188            .outgoing
189            .write()
190            .insert(self.local_id.clone(), self.message_tx.clone());
191    }
192}
193
194#[async_trait]
195impl NetworkTransport for MemoryTransport {
196    async fn connect(&self, peer_id: &PeerId) -> Result<(), NetworkError> {
197        self.peers.write().insert(
198            peer_id.clone(),
199            Peer {
200                id: peer_id.clone(),
201                name: peer_id.0.clone(),
202                state: PeerState::Connected,
203            },
204        );
205        Ok(())
206    }
207
208    async fn disconnect(&self, peer_id: &PeerId) -> Result<(), NetworkError> {
209        self.peers.write().remove(peer_id);
210        self.outgoing.write().remove(peer_id);
211        Ok(())
212    }
213
214    async fn send(&self, peer_id: &PeerId, message: Message) -> Result<(), NetworkError> {
215        let tx = {
216            let outgoing = self.outgoing.read();
217            outgoing.get(peer_id).cloned()
218        };
219
220        if let Some(tx) = tx {
221            tx.send((self.local_id.clone(), message))
222                .await
223                .map_err(|e| NetworkError::SendFailed(e.to_string()))?;
224            Ok(())
225        } else {
226            Err(NetworkError::PeerNotFound(peer_id.to_string()))
227        }
228    }
229
230    async fn broadcast(&self, message: Message) -> Result<(), NetworkError> {
231        let senders: Vec<_> = {
232            let outgoing = self.outgoing.read();
233            outgoing.values().cloned().collect()
234        };
235
236        for tx in senders {
237            let _ = tx.send((self.local_id.clone(), message.clone())).await;
238        }
239        Ok(())
240    }
241
242    async fn connected_peers(&self) -> Vec<Peer> {
243        self.peers.read().values().cloned().collect()
244    }
245
246    fn subscribe(&self) -> mpsc::Receiver<(PeerId, Message)> {
247        self.message_rx
248            .write()
249            .take()
250            .expect("subscribe can only be called once")
251    }
252}
253
254/// Create a fully connected in-memory transport network.
255///
256/// Each peer is connected to every other peer, making this helper ideal for
257/// deterministic tests and examples.
258pub fn create_network(count: usize) -> Vec<MemoryTransport> {
259    let transports: Vec<_> = (0..count)
260        .map(|i| MemoryTransport::new(PeerId::new(format!("peer-{}", i))))
261        .collect();
262
263    // Connect all peers to each other
264    for i in 0..count {
265        for j in (i + 1)..count {
266            transports[i].connect_to(&transports[j]);
267        }
268    }
269
270    transports
271}
272
#[cfg(test)]
mod tests {
    use super::*;

    /// Linking two transports registers each one in the other's peer table.
    #[tokio::test]
    async fn test_memory_transport() {
        let first = MemoryTransport::new(PeerId::new("peer-1"));
        let second = MemoryTransport::new(PeerId::new("peer-2"));

        first.connect_to(&second);

        assert_eq!(first.connected_peers().await.len(), 1);
        assert_eq!(second.connected_peers().await.len(), 1);
    }

    /// A three-node network is a full mesh: every node sees the other two.
    #[tokio::test]
    async fn test_network_creation() {
        let mesh = create_network(3);
        assert_eq!(mesh.len(), 3);

        for node in mesh.iter() {
            let neighbours = node.connected_peers().await;
            assert_eq!(neighbours.len(), 2);
        }
    }
}