1use std::{
2 convert::Infallible,
3 sync::{Arc, Weak},
4 time::{Duration, Instant},
5};
6
7use anyhow::Context;
8use async_trait::async_trait;
9use futures_util::TryFutureExt;
10use itertools::Itertools;
11use nanorpc::{DynRpcTransport, OrService, RpcService, RpcTransport};
12use smol::{lock::RwLock, Task};
13use smol_timeout::TimeoutExt;
14
15use crate::{
16 protocol::{Address, ControlClient, ControlProtocol, ControlService},
17 Backhaul,
18};
19
20use self::routedb::RouteDb;
21
/// Soft cap on the size of the route table.
///
/// Route maintenance only pulls additional routes while the table is below this size, and
/// `__mn_advertise_peer` rejects advertisements once the table has reached it. A single
/// gossip pull can still insert several routes at once, so the table may briefly exceed it.
const ROUTE_LIMIT: usize = 32;
23
24pub struct Swarm<B: Backhaul, C> {
28 haul: Arc<B>,
30 routes: Arc<smol::lock::RwLock<RouteDb>>,
32 open_client: Arc<dyn Fn(DynRpcTransport) -> C + Sync + Send + 'static>,
34 swarm_id: String,
36
37 _route_maintain_task: Arc<Task<Infallible>>,
39}
40
41impl<B: Backhaul, C> Clone for Swarm<B, C> {
42 fn clone(&self) -> Self {
43 Self {
44 haul: self.haul.clone(),
45 routes: self.routes.clone(),
46 open_client: self.open_client.clone(),
47 swarm_id: self.swarm_id.clone(),
48 _route_maintain_task: self._route_maintain_task.clone(),
49 }
50 }
51}
52
53impl<B: Backhaul, C: 'static> Swarm<B, C>
54where
55 <B::RpcTransport as RpcTransport>::Error: std::error::Error + Send + Sync,
56{
57 pub fn new(
59 backhaul: B,
60 client_map_fn: impl Fn(DynRpcTransport) -> C + Sync + Send + 'static,
61 swarm_id: &str,
62 ) -> Self {
63 let haul = Arc::new(backhaul);
64 let routes = Arc::new(smol::lock::RwLock::new(RouteDb::default()));
65 let open_client = Arc::new(client_map_fn);
66 Self {
67 haul: haul.clone(),
68 routes: routes.clone(),
69 open_client: open_client.clone(),
70 swarm_id: swarm_id.to_string(),
71
72 _route_maintain_task: smolscale::spawn(Self::route_maintain(
73 haul,
74 routes,
75 open_client,
76 swarm_id.to_string(),
77 ))
78 .into(),
79 }
80 }
81
82 pub async fn connect(&self, addr: Address) -> Result<C, B::ConnectError> {
84 Ok((self.open_client)(DynRpcTransport::new(
85 self.haul.connect(addr).await?,
86 )))
87 }
88
89 pub async fn connect_lazy(&self, addr: Address) -> Result<C, B::ConnectError> {
91 Ok((self.open_client)(DynRpcTransport::new(
92 self.haul.connect_lazy(addr).await,
93 )))
94 }
95
96 pub async fn start_listen(
98 &self,
99 listen_addr: Address,
100 advertise_addr: Option<Address>,
101 service: impl RpcService,
102 ) -> Result<(), B::ListenError> {
103 self.haul
104 .start_listen(
105 listen_addr,
106 OrService::new(
107 service,
108 ControlService(ControlProtocolImpl {
109 swarm_id: self.swarm_id.clone(),
110 routes: self.routes.clone(),
111 weak_haul: Arc::downgrade(&self.haul),
112 }),
113 ),
114 )
115 .await?;
116 if let Some(advertise_addr) = advertise_addr {
117 self.routes
118 .write()
119 .await
120 .insert(advertise_addr, Duration::from_millis(0), true)
121 }
122 Ok(())
123 }
124
125 pub async fn routes(&self) -> Vec<Address> {
127 self.routes
128 .read()
129 .await
130 .random_iter()
131 .map(|s| s.addr)
132 .collect_vec()
133 }
134
135 pub async fn add_route(&self, addr: Address, sticky: bool) {
137 let mut v = self.routes.write().await;
138 v.insert(addr, Duration::from_secs(1), sticky);
139 }
140
141 #[allow(clippy::redundant_locals)]
143 async fn route_maintain(
144 haul: Arc<B>,
145 routes: Arc<smol::lock::RwLock<RouteDb>>,
146 _open_client: Arc<dyn Fn(DynRpcTransport) -> C + Sync + Send + 'static>,
147 swarm_id: String,
148 ) -> Infallible {
149 const PULSE: Duration = Duration::from_secs(1);
150 let mut timer = smol::Timer::interval(PULSE);
151 let exec = smol::Executor::new();
152 exec.run(async {
154 loop {
155 if fastrand::f64() * 3.0 < 1.0 {
156 log::debug!("[{swarm_id}] push pulse");
157 if let Some(random) = routes.read().await.random_iter().next() {
158 if let Some(to_send) = routes
159 .read()
160 .await
161 .random_iter()
162 .find(|r| r.addr != random.addr)
163 {
164 let random = random.addr;
165 let to_send = to_send.addr;
166 log::debug!("[{swarm_id}] push {to_send} => {random}");
167 let random2 = random.clone();
168 exec.spawn(
169 async {
170 let to_send = to_send;
171 let random = random;
172 let conn = ControlClient(
173 haul.connect(random)
174 .timeout(Duration::from_secs(60))
175 .await
176 .context("connect timeout")??,
177 );
178 conn.__mn_advertise_peer(to_send)
179 .timeout(Duration::from_secs(60))
180 .await
181 .context("advertise timeout")??;
182 anyhow::Ok(())
183 }
184 .unwrap_or_else(|e| {
185 let random2 = random2;
186 log::warn!("[{swarm_id}] push failed to {}: {e}", random2)
187 }),
188 )
189 .detach();
190 }
191 }
192 }
193
194 if fastrand::f64() * 10.0 < 1.0 {
195 let current_count = routes.read().await.count();
196 log::debug!("[{swarm_id}] pull pulse {current_count}/{ROUTE_LIMIT}");
197 if current_count < ROUTE_LIMIT {
198 if let Some(route) = routes.read().await.random_iter().next() {
200 log::debug!("[{swarm_id}] getting more routes from {}", route.addr);
201 let route2 = route.clone();
202 exec.spawn(
203 async {
204 let route = route;
205 let conn = ControlClient(
206 haul.connect(route.addr.clone())
207 .timeout(Duration::from_secs(60))
208 .await
209 .context("connect timeout")??,
210 );
211 for peer in conn
212 .__mn_get_random_peers()
213 .timeout(Duration::from_secs(60))
214 .await
215 .context("get peers timeout")??
216 {
217 log::debug!(
218 "[{swarm_id}] got route {} from {}",
219 route.addr,
220 peer
221 );
222 let ping = test_ping(&haul, peer.clone(), &swarm_id)
223 .await
224 .context("ping failed")?;
225 routes.write().await.insert(peer, ping, false)
226 }
227 anyhow::Ok(())
228 }
229 .unwrap_or_else(|e| {
230 let route = route2;
231 log::warn!(
232 "[{swarm_id}] get more routes failed from {}: {e}",
233 route.addr
234 )
235 }),
236 )
237 .detach();
238 } else if current_count > ROUTE_LIMIT {
239 let mut routes = routes.write().await;
241 routes.remove_worst();
242 }
243 }
244 }
245 if fastrand::f64() * 300.0 < 1.0 {
246 log::debug!("[{swarm_id}] ping pulse...");
247 let routes_guard = routes.read().await;
248 for route in routes_guard.random_iter() {
249 exec.spawn(async {
250 let route = route;
251 if let Err(err) = test_ping(&haul, route.addr.clone(), &swarm_id).await
252 {
253 if route.sticky {
254 log::debug!(
255 "[{swarm_id}] keeping sticky {} despite ping-fail: {err}",
256 route.addr
257 );
258 } else {
259 log::debug!(
260 "[{swarm_id}] ping-failing non-sticky {}: {err}",
261 route.addr
262 );
263 routes.write().await.remove(route.addr);
264 }
265 }
266 })
267 .detach();
268 }
269 }
270 (&mut timer).await;
271 }
272 })
273 .await
274 }
275}
276
277struct ControlProtocolImpl<B: Backhaul> {
278 swarm_id: String,
279 routes: Arc<RwLock<RouteDb>>,
280 weak_haul: Weak<B>,
281}
282
283#[async_trait]
284impl<B: Backhaul> ControlProtocol for ControlProtocolImpl<B>
285where
286 <B::RpcTransport as RpcTransport>::Error: std::error::Error + Send + Sync,
287{
288 async fn __mn_get_swarm_id(&self) -> String {
289 self.swarm_id.clone()
290 }
291
292 async fn __mn_get_random_peers(&self) -> Vec<Address> {
293 self.routes
294 .read()
295 .await
296 .random_iter()
297 .take(8)
298 .map(|r| r.addr)
299 .collect()
300 }
301
302 async fn __mn_advertise_peer(&self, addr: Address) -> bool {
303 if self.routes.read().await.count() >= ROUTE_LIMIT {
304 return false;
305 }
306 if let Some(haul) = self.weak_haul.upgrade() {
307 if let Ok(ping) = test_ping(&haul, addr.clone(), &self.swarm_id).await {
308 self.routes.write().await.insert(addr, ping, false);
309 }
310 }
311 return true;
312 }
313}
314
315async fn test_ping<B: Backhaul>(
316 haul: &Arc<B>,
317 addr: Address,
318 swarm_id: &str,
319) -> anyhow::Result<Duration>
320where
321 <B::RpcTransport as RpcTransport>::Error: std::error::Error + Send + Sync,
322{
323 let start = Instant::now();
324 let client = ControlClient(
325 haul.connect(addr)
326 .timeout(Duration::from_secs(5))
327 .await
328 .context("connect timed out after 5 seconds")??,
329 );
330 let their_swarm_id = client
331 .__mn_get_swarm_id()
332 .timeout(Duration::from_secs(5))
333 .await
334 .context("ping timed out after 5 seconds")??;
335 if their_swarm_id != swarm_id {
336 anyhow::bail!(
337 "their swarm ID {:?} is not our swarm ID {:?}",
338 their_swarm_id,
339 swarm_id
340 );
341 }
342 Ok(start.elapsed())
343}
344
345mod routedb {
346 use itertools::Itertools;
347
348 use super::*;
349 #[derive(Default)]
350 pub struct RouteDb {
351 routes: Vec<Route>,
353 }
354
355 impl RouteDb {
356 pub fn get_route(&self, addr: Address) -> Option<Route> {
357 self.routes.iter().find(|r| r.addr == addr).cloned()
358 }
359
360 fn get_route_mut(&mut self, addr: Address) -> Option<&mut Route> {
361 self.routes.iter_mut().find(|r| r.addr == addr)
362 }
363
364 pub fn insert(&mut self, addr: Address, ping: Duration, sticky: bool) {
365 if let Some(r) = self.get_route_mut(addr.clone()) {
366 r.last_ping = ping;
367 r.last_seen = Instant::now();
368 r.sticky = sticky;
369 } else {
370 self.routes.push(Route {
371 addr,
372 last_ping: ping,
373 last_seen: Instant::now(),
374 sticky,
375 })
376 }
377 }
378
379 pub fn random_iter(&self) -> impl Iterator<Item = Route> + '_ {
380 std::iter::repeat_with(|| fastrand::usize(0..self.routes.len()))
381 .map(|i| self.routes[i].clone())
382 .unique()
383 .take(self.routes.len())
384 }
385
386 pub fn remove(&mut self, addr: Address) {
387 self.routes.retain(|s| s.addr != addr);
388 }
389
390 pub fn remove_worst(&mut self) {
391 self.routes.sort_unstable_by_key(|s| s.last_ping);
392 let _ = self.routes.pop();
393 }
394
395 pub fn count(&self) -> usize {
396 self.routes.len()
397 }
398 }
399
400 #[derive(Clone, Hash, PartialEq, Eq)]
401 pub struct Route {
402 pub addr: Address,
403 pub last_ping: Duration,
404 pub last_seen: Instant,
405 pub sticky: bool,
406 }
407}