1use std::collections::HashMap;
4use std::sync::Arc;
5
6use tokio::sync::{mpsc, RwLock};
7
8use crate::error::{MeshError, Result};
9use crate::message::{PeerCapability, PeerId, PeerMessage};
10
11pub struct PeerHandle {
13 pub receiver: mpsc::Receiver<PeerMessage>,
16 pub id: PeerId,
18}
19
20#[derive(Clone)]
23pub struct MeshHandle {
24 mesh: Arc<MeshInner>,
25 pub id: PeerId,
27}
28
29struct MeshInner {
30 peers: RwLock<HashMap<PeerId, PeerEntry>>,
31}
32
33struct PeerEntry {
34 sender: mpsc::Sender<PeerMessage>,
35 capabilities: Vec<PeerCapability>,
36 topics: Vec<String>,
38}
39
/// Number of messages each peer's inbox channel can buffer.
const CHANNEL_CAPACITY: usize = 256;
41
42#[derive(Clone)]
44pub struct LocalMesh {
45 inner: Arc<MeshInner>,
46}
47
48impl LocalMesh {
49 pub fn new() -> Self {
51 Self {
52 inner: Arc::new(MeshInner {
53 peers: RwLock::new(HashMap::new()),
54 }),
55 }
56 }
57
58 pub async fn peer_count(&self) -> usize {
60 self.inner.peers.read().await.len()
61 }
62
63 pub async fn directory(&self) -> Vec<(PeerId, Vec<PeerCapability>)> {
65 self.inner
66 .peers
67 .read()
68 .await
69 .iter()
70 .map(|(id, entry)| (id.clone(), entry.capabilities.clone()))
71 .collect()
72 }
73
74 pub async fn join(
81 &self,
82 id: impl Into<PeerId>,
83 capabilities: Vec<PeerCapability>,
84 topics: Vec<String>,
85 ) -> Result<(PeerHandle, MeshHandle)> {
86 let id: PeerId = id.into();
87 let (tx, rx) = mpsc::channel(CHANNEL_CAPACITY);
88 let mut peers = self.inner.peers.write().await;
89 peers.insert(
90 id.clone(),
91 PeerEntry {
92 sender: tx,
93 capabilities,
94 topics,
95 },
96 );
97 Ok((
98 PeerHandle {
99 receiver: rx,
100 id: id.clone(),
101 },
102 MeshHandle {
103 mesh: self.inner.clone(),
104 id,
105 },
106 ))
107 }
108
109 pub async fn leave(&self, id: &PeerId) -> Result<()> {
111 self.inner.peers.write().await.remove(id);
112 Ok(())
113 }
114}
115
116impl Default for LocalMesh {
117 fn default() -> Self {
118 Self::new()
119 }
120}
121
122impl MeshHandle {
123 pub async fn publish(&self, msg: PeerMessage) -> Result<()> {
134 let peers = self.mesh.peers.read().await;
135 match &msg {
136 PeerMessage::Direct { to, .. } => {
137 let entry = peers
138 .get(to)
139 .ok_or_else(|| MeshError::UnknownPeer(to.clone()))?;
140 let _ = entry.sender.try_send(msg.clone());
141 }
142 PeerMessage::Broadcast { topic, .. } => {
143 for (id, entry) in peers.iter() {
144 if id == self.sender() {
145 continue;
146 }
147 if entry.topics.is_empty() || entry.topics.contains(topic) {
148 let _ = entry.sender.try_send(msg.clone());
149 }
150 }
151 }
152 _ => {
153 for (id, entry) in peers.iter() {
154 if id == self.sender() {
155 continue;
156 }
157 let _ = entry.sender.try_send(msg.clone());
158 }
159 }
160 }
161 Ok(())
162 }
163
164 pub fn sender(&self) -> &PeerId {
166 &self.id
167 }
168}
169
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::time::Duration;
    use tokio::time::timeout;

    /// How long to wait for a message that is expected to arrive.
    const ARRIVAL_WAIT: Duration = Duration::from_millis(200);
    /// How long to wait before concluding that no message was delivered.
    const SILENCE_WAIT: Duration = Duration::from_millis(100);

    /// Builds a single-capability list with the given name and no version.
    fn caps(name: &str) -> Vec<PeerCapability> {
        vec![PeerCapability {
            name: name.into(),
            version: None,
        }]
    }

    #[tokio::test]
    async fn join_and_leave_peers() {
        let mesh = LocalMesh::new();
        let (_peer_a, _handle_a) = mesh.join("a", caps("browser"), vec![]).await.unwrap();
        let (_peer_b, _handle_b) = mesh.join("b", caps("mirror"), vec![]).await.unwrap();
        assert_eq!(mesh.peer_count().await, 2);
        mesh.leave(&"a".to_string()).await.unwrap();
        assert_eq!(mesh.peer_count().await, 1);
    }

    #[tokio::test]
    async fn direct_message_routes_to_one_peer() {
        let mesh = LocalMesh::new();
        let (_peer_a, handle_a) = mesh.join("a", caps("x"), vec![]).await.unwrap();
        let (mut peer_b, _handle_b) = mesh.join("b", caps("x"), vec![]).await.unwrap();
        let (mut peer_c, _handle_c) = mesh.join("c", caps("x"), vec![]).await.unwrap();

        handle_a
            .publish(PeerMessage::direct("a", "b", json!({"hello": 1})))
            .await
            .unwrap();

        let msg = timeout(ARRIVAL_WAIT, peer_b.receiver.recv())
            .await
            .unwrap()
            .unwrap();
        assert_eq!(msg.sender(), "a");

        let no_msg = timeout(SILENCE_WAIT, peer_c.receiver.recv()).await;
        assert!(
            no_msg.is_err(),
            "peer c should not have received the direct"
        );
    }

    #[tokio::test]
    async fn direct_message_to_unknown_peer_is_an_error() {
        let mesh = LocalMesh::new();
        let (_peer_a, handle_a) = mesh.join("a", caps("x"), vec![]).await.unwrap();

        let outcome = handle_a
            .publish(PeerMessage::direct("a", "ghost", json!({"hello": 1})))
            .await;
        assert!(
            matches!(outcome, Err(MeshError::UnknownPeer(_))),
            "a direct to an unregistered peer should report UnknownPeer"
        );
    }

    #[tokio::test]
    async fn broadcast_filters_by_topic() {
        let mesh = LocalMesh::new();
        let (_pa, ha) = mesh.join("a", caps("x"), vec![]).await.unwrap();
        let (mut pb, _hb) = mesh
            .join("b", caps("x"), vec!["pets".into()])
            .await
            .unwrap();
        let (mut pc, _hc) = mesh
            .join("c", caps("x"), vec!["other".into()])
            .await
            .unwrap();

        ha.publish(PeerMessage::broadcast("a", "pets", json!({"id": 1})))
            .await
            .unwrap();

        assert!(timeout(ARRIVAL_WAIT, pb.receiver.recv())
            .await
            .unwrap()
            .is_some());
        assert!(timeout(SILENCE_WAIT, pc.receiver.recv()).await.is_err());
    }

    #[tokio::test]
    async fn task_round_trip_uses_direct_or_broadcast() {
        let mesh = LocalMesh::new();
        let (mut pa, ha) = mesh.join("a", caps("x"), vec![]).await.unwrap();
        let (mut pb, hb) = mesh.join("b", caps("x"), vec![]).await.unwrap();
        ha.publish(PeerMessage::task("a", json!({"do": "x"})))
            .await
            .unwrap();

        // Bounded wait so a routing regression fails the test instead of
        // hanging the whole test run.
        let msg = timeout(ARRIVAL_WAIT, pb.receiver.recv())
            .await
            .expect("task should reach peer b before the timeout")
            .expect("peer b's channel closed unexpectedly");
        let task_id = match &msg {
            PeerMessage::Task { task_id, .. } => task_id.clone(),
            _ => panic!("expected task"),
        };
        hb.publish(PeerMessage::Result {
            from: "b".into(),
            task_id,
            result: json!({"ok": 1}),
            ok: true,
        })
        .await
        .unwrap();

        // Complete the round trip: the result must come back to peer a.
        let reply = timeout(ARRIVAL_WAIT, pa.receiver.recv())
            .await
            .expect("result should reach peer a before the timeout")
            .expect("peer a's channel closed unexpectedly");
        assert!(
            matches!(reply, PeerMessage::Result { ok: true, .. }),
            "peer a should receive the successful result"
        );
    }

    #[tokio::test]
    async fn directory_lists_capabilities() {
        let mesh = LocalMesh::new();
        mesh.join("a", caps("browser"), vec![]).await.unwrap();
        mesh.join("b", caps("mirror"), vec![]).await.unwrap();
        let dir = mesh.directory().await;
        assert_eq!(dir.len(), 2);
    }
}