Skip to main content

oxide_mesh/
local.rs

1//! In-process mesh built on `tokio::sync::mpsc` channels.
2
3use std::collections::HashMap;
4use std::sync::Arc;
5
6use tokio::sync::{mpsc, RwLock};
7
8use crate::error::{MeshError, Result};
9use crate::message::{PeerCapability, PeerId, PeerMessage};
10
11/// Per-peer channel handle.
12pub struct PeerHandle {
13    /// Receiver for messages routed to this peer (Direct + Broadcast +
14    /// Tasks). Tests / agents pop from here.
15    pub receiver: mpsc::Receiver<PeerMessage>,
16    /// Peer id.
17    pub id: PeerId,
18}
19
20/// Shared handle returned by [`LocalMesh::join`]. Keeps the channel alive and
21/// lets the peer publish back into the mesh.
22#[derive(Clone)]
23pub struct MeshHandle {
24    mesh: Arc<MeshInner>,
25    /// Peer id.
26    pub id: PeerId,
27}
28
29struct MeshInner {
30    peers: RwLock<HashMap<PeerId, PeerEntry>>,
31}
32
33struct PeerEntry {
34    sender: mpsc::Sender<PeerMessage>,
35    capabilities: Vec<PeerCapability>,
36    /// Topics this peer subscribes to (empty = all).
37    topics: Vec<String>,
38}
39
/// Capacity of each peer's bounded inbox. `publish` never waits for space, so
/// a message addressed to a peer whose inbox is full is dropped.
const CHANNEL_CAPACITY: usize = 256;
41
42/// Default in-process mesh. Cheaply cloneable.
43#[derive(Clone)]
44pub struct LocalMesh {
45    inner: Arc<MeshInner>,
46}
47
48impl LocalMesh {
49    /// Build an empty mesh.
50    pub fn new() -> Self {
51        Self {
52            inner: Arc::new(MeshInner {
53                peers: RwLock::new(HashMap::new()),
54            }),
55        }
56    }
57
58    /// Number of joined peers.
59    pub async fn peer_count(&self) -> usize {
60        self.inner.peers.read().await.len()
61    }
62
63    /// Snapshot of every peer's capabilities.
64    pub async fn directory(&self) -> Vec<(PeerId, Vec<PeerCapability>)> {
65        self.inner
66            .peers
67            .read()
68            .await
69            .iter()
70            .map(|(id, entry)| (id.clone(), entry.capabilities.clone()))
71            .collect()
72    }
73
74    /// Join the mesh as `id` with the given capabilities. Returns a
75    /// [`PeerHandle`] (the receiver end) and a [`MeshHandle`] used to
76    /// publish.
77    ///
78    /// Subscribing to `topics` is optional; an empty list means "receive all
79    /// broadcasts".
80    pub async fn join(
81        &self,
82        id: impl Into<PeerId>,
83        capabilities: Vec<PeerCapability>,
84        topics: Vec<String>,
85    ) -> Result<(PeerHandle, MeshHandle)> {
86        let id: PeerId = id.into();
87        let (tx, rx) = mpsc::channel(CHANNEL_CAPACITY);
88        let mut peers = self.inner.peers.write().await;
89        peers.insert(
90            id.clone(),
91            PeerEntry {
92                sender: tx,
93                capabilities,
94                topics,
95            },
96        );
97        Ok((
98            PeerHandle {
99                receiver: rx,
100                id: id.clone(),
101            },
102            MeshHandle {
103                mesh: self.inner.clone(),
104                id,
105            },
106        ))
107    }
108
109    /// Disconnect a peer.
110    pub async fn leave(&self, id: &PeerId) -> Result<()> {
111        self.inner.peers.write().await.remove(id);
112        Ok(())
113    }
114}
115
116impl Default for LocalMesh {
117    fn default() -> Self {
118        Self::new()
119    }
120}
121
122impl MeshHandle {
123    /// Publish a message to whichever peers it is destined for.
124    ///
125    /// Routing rules:
126    /// - [`PeerMessage::Hello`] / [`PeerMessage::Result`] go to every peer
127    ///   except the sender.
128    /// - [`PeerMessage::Direct`] goes to `to` only.
129    /// - [`PeerMessage::Broadcast`] goes to every peer subscribed to the
130    ///   topic (empty subscription list = all topics).
131    /// - [`PeerMessage::Task`] goes to every peer except the sender — first
132    ///   to claim it wins, by convention.
133    pub async fn publish(&self, msg: PeerMessage) -> Result<()> {
134        let peers = self.mesh.peers.read().await;
135        match &msg {
136            PeerMessage::Direct { to, .. } => {
137                let entry = peers
138                    .get(to)
139                    .ok_or_else(|| MeshError::UnknownPeer(to.clone()))?;
140                let _ = entry.sender.try_send(msg.clone());
141            }
142            PeerMessage::Broadcast { topic, .. } => {
143                for (id, entry) in peers.iter() {
144                    if id == self.sender() {
145                        continue;
146                    }
147                    if entry.topics.is_empty() || entry.topics.contains(topic) {
148                        let _ = entry.sender.try_send(msg.clone());
149                    }
150                }
151            }
152            _ => {
153                for (id, entry) in peers.iter() {
154                    if id == self.sender() {
155                        continue;
156                    }
157                    let _ = entry.sender.try_send(msg.clone());
158                }
159            }
160        }
161        Ok(())
162    }
163
164    /// Sender's peer id.
165    pub fn sender(&self) -> &PeerId {
166        &self.id
167    }
168}
169
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::time::Duration;
    use tokio::sync::mpsc::error::TryRecvError;

    /// Builds a single-capability list with no version.
    fn caps(name: &str) -> Vec<PeerCapability> {
        vec![PeerCapability {
            name: name.into(),
            version: None,
        }]
    }

    /// Waits up to 200 ms for the next message on `peer`, panicking on timeout.
    ///
    /// Returns `None` if the peer's channel was closed.
    async fn recv_next(peer: &mut PeerHandle) -> Option<PeerMessage> {
        tokio::time::timeout(Duration::from_millis(200), peer.receiver.recv())
            .await
            .expect("timed out waiting for a message")
    }

    /// Asserts that nothing is queued for `peer`.
    ///
    /// `publish` enqueues with `try_send` before it returns, so an empty queue
    /// right after a publish means the message was not routed here. No timeout
    /// is needed.
    fn assert_nothing_queued(peer: &mut PeerHandle) {
        assert!(
            matches!(peer.receiver.try_recv(), Err(TryRecvError::Empty)),
            "peer {} should not have received anything",
            peer.id
        );
    }

    #[tokio::test]
    async fn join_and_leave_peers() {
        let mesh = LocalMesh::new();
        let (mut peer_a, _handle_a) = mesh.join("a", caps("browser"), vec![]).await.unwrap();
        let (_peer_b, _handle_b) = mesh.join("b", caps("mirror"), vec![]).await.unwrap();
        assert_eq!(mesh.peer_count().await, 2);

        mesh.leave(&"a".to_string()).await.unwrap();
        assert_eq!(mesh.peer_count().await, 1);
        // The mesh held the only sender for "a", so its channel is now closed.
        assert!(recv_next(&mut peer_a).await.is_none());

        // Leaving a peer that is no longer joined is a no-op.
        mesh.leave(&"a".to_string()).await.unwrap();
        assert_eq!(mesh.peer_count().await, 1);
    }

    #[tokio::test]
    async fn rejoining_replaces_previous_entry() {
        let mesh = LocalMesh::new();
        let (mut old, _old_handle) = mesh.join("a", caps("x"), vec![]).await.unwrap();
        let (_new, _new_handle) = mesh.join("a", caps("y"), vec![]).await.unwrap();
        assert_eq!(mesh.peer_count().await, 1);
        // The replaced entry's sender was dropped, closing the old channel.
        assert!(recv_next(&mut old).await.is_none());
    }

    #[tokio::test]
    async fn direct_message_routes_to_one_peer() {
        let mesh = LocalMesh::new();
        let (mut peer_a, handle_a) = mesh.join("a", caps("x"), vec![]).await.unwrap();
        let (mut peer_b, _handle_b) = mesh.join("b", caps("x"), vec![]).await.unwrap();
        let (mut peer_c, _handle_c) = mesh.join("c", caps("x"), vec![]).await.unwrap();

        handle_a
            .publish(PeerMessage::direct("a", "b", json!({"hello": 1})))
            .await
            .unwrap();

        let msg = recv_next(&mut peer_b)
            .await
            .expect("b should receive the direct");
        assert_eq!(msg.sender(), "a");
        assert_nothing_queued(&mut peer_a);
        assert_nothing_queued(&mut peer_c);
    }

    #[tokio::test]
    async fn direct_to_unknown_peer_is_an_error() {
        let mesh = LocalMesh::new();
        let (_peer_a, handle_a) = mesh.join("a", caps("x"), vec![]).await.unwrap();

        let err = handle_a
            .publish(PeerMessage::direct("a", "nobody", json!({})))
            .await
            .unwrap_err();
        assert!(matches!(&err, MeshError::UnknownPeer(id) if id == "nobody"));
    }

    #[tokio::test]
    async fn broadcast_filters_by_topic() {
        let mesh = LocalMesh::new();
        let (mut pa, ha) = mesh.join("a", caps("x"), vec![]).await.unwrap();
        let (mut pb, _hb) = mesh
            .join("b", caps("x"), vec!["pets".into()])
            .await
            .unwrap();
        let (mut pc, _hc) = mesh
            .join("c", caps("x"), vec!["other".into()])
            .await
            .unwrap();
        let (mut pd, _hd) = mesh.join("d", caps("x"), vec![]).await.unwrap();

        ha.publish(PeerMessage::broadcast("a", "pets", json!({"id": 1})))
            .await
            .unwrap();

        assert!(recv_next(&mut pb).await.is_some(), "b subscribes to pets");
        assert!(recv_next(&mut pd).await.is_some(), "d subscribes to everything");
        assert_nothing_queued(&mut pc);
        // The publisher is skipped even though its empty list matches all topics.
        assert_nothing_queued(&mut pa);
    }

    #[tokio::test]
    async fn task_and_result_round_trip() {
        let mesh = LocalMesh::new();
        let (mut pa, ha) = mesh.join("a", caps("x"), vec![]).await.unwrap();
        let (mut pb, hb) = mesh.join("b", caps("x"), vec![]).await.unwrap();

        ha.publish(PeerMessage::task("a", json!({"do": "x"})))
            .await
            .unwrap();
        assert_nothing_queued(&mut pa);

        let task_id = match recv_next(&mut pb).await.expect("b should receive the task") {
            PeerMessage::Task { task_id, .. } => task_id,
            _ => panic!("expected task"),
        };

        hb.publish(PeerMessage::Result {
            from: "b".into(),
            task_id: task_id.clone(),
            result: json!({"ok": 1}),
            ok: true,
        })
        .await
        .unwrap();
        assert_nothing_queued(&mut pb);

        match recv_next(&mut pa).await.expect("a should receive the result") {
            PeerMessage::Result {
                task_id: got, ok, ..
            } => {
                assert_eq!(got, task_id);
                assert!(ok);
            }
            _ => panic!("expected result"),
        }
    }

    #[tokio::test]
    async fn directory_lists_capabilities() {
        let mesh = LocalMesh::new();
        mesh.join("a", caps("browser"), vec![]).await.unwrap();
        mesh.join("b", caps("mirror"), vec![]).await.unwrap();

        let dir = mesh.directory().await;
        let mut summary: Vec<(&str, &str)> = dir
            .iter()
            .map(|(id, advertised)| (id.as_str(), advertised[0].name.as_str()))
            .collect();
        // `directory` follows HashMap iteration order, so sort before comparing.
        summary.sort_unstable();
        assert_eq!(summary, [("a", "browser"), ("b", "mirror")]);
    }
}