1use std::{collections::BTreeMap, ops::AddAssign, sync::Arc};
8
9use parking_lot::RwLock;
10use serde::{Deserialize, Serialize};
11
12use crate::ConnectionId;
13
14use super::PeerId;
15
/// Largest relay-hop count a path may carry and still be advertised to a
/// neighbour by `RouterTable::create_sync`; longer paths are kept locally only.
const MAX_HOPS: u8 = 6;
17
18#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone, Copy)]
19pub struct PathMetric {
20 relay_hops: u8,
21 rtt_ms: u16,
22}
23
24impl From<(u8, u16)> for PathMetric {
25 fn from(value: (u8, u16)) -> Self {
26 Self { relay_hops: value.0, rtt_ms: value.1 }
27 }
28}
29
30#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
31pub struct RouterTableSync(Vec<(PeerId, PathMetric)>);
32
33#[derive(Debug, Default)]
34struct PeerMemory {
35 best: Option<ConnectionId>,
36 paths: BTreeMap<ConnectionId, PathMetric>,
37}
38
39#[derive(Debug, PartialEq, Eq)]
40pub enum RouteAction {
41 Local,
42 Next(ConnectionId),
43}
44
45#[derive(Debug)]
46struct RouterTable {
47 peer_id: PeerId,
48 peers: BTreeMap<PeerId, PeerMemory>,
49 directs: BTreeMap<ConnectionId, (PeerId, PathMetric)>,
50}
51
52impl RouterTable {
53 fn new(peer_id: PeerId) -> Self {
54 Self {
55 peer_id,
56 peers: Default::default(),
57 directs: Default::default(),
58 }
59 }
60
61 fn local_id(&self) -> PeerId {
62 self.peer_id
63 }
64
65 fn create_sync(&self, dest: &PeerId) -> RouterTableSync {
66 RouterTableSync(
67 self.peers
68 .iter()
69 .map(|(addr, history)| (*addr, history.best_metric().expect("should have best")))
70 .filter(|(addr, metric)| !dest.eq(addr) && metric.relay_hops <= MAX_HOPS)
71 .collect::<Vec<_>>(),
72 )
73 }
74
75 fn apply_sync(&mut self, conn: ConnectionId, sync: RouterTableSync) {
76 let (from_peer, direct_metric) = self.directs.get(&conn).expect("should have direct metric with apply_sync");
77 for (peer, _) in sync.0.iter() {
79 self.peers.entry(*peer).or_default();
80 }
81
82 let mut new_paths = BTreeMap::<PeerId, PathMetric>::from_iter(sync.0);
83 for (peer, memory) in self.peers.iter_mut().filter(|(p, _)| !from_peer.eq(p)) {
85 let previous = memory.paths.contains_key(&conn);
86 let current = new_paths.remove(peer);
87 match (previous, current) {
88 (true, Some(mut new_metric)) => {
89 new_metric += *direct_metric;
91 memory.paths.insert(conn, new_metric);
92 Self::select_best_for(peer, memory);
93 }
94 (true, None) => {
95 log::info!("[RouterTable] remove path for {peer}");
97 memory.paths.remove(&conn);
98 Self::select_best_for(peer, memory);
99 }
100 (false, Some(mut new_metric)) => {
101 log::info!("[RouterTable] create path for {peer}");
103 new_metric += *direct_metric;
104 memory.paths.insert(conn, new_metric);
105 Self::select_best_for(peer, memory);
106 }
107 _ => { }
109 }
110 }
111 self.peers.retain(|_k, v| v.best().is_some());
112 }
113
114 fn set_direct(&mut self, conn: ConnectionId, to: PeerId, ttl_ms: u16) {
115 self.directs.insert(conn, (to, (1, ttl_ms).into()));
116 let memory = self.peers.entry(to).or_default();
117 memory.paths.insert(conn, PathMetric { relay_hops: 0, rtt_ms: ttl_ms });
118 Self::select_best_for(&to, memory);
119 }
120
121 fn del_direct(&mut self, conn: &ConnectionId) {
122 if let Some((to, _)) = self.directs.remove(conn) {
123 if let Some(memory) = self.peers.get_mut(&to) {
124 memory.paths.remove(conn);
125 Self::select_best_for(&to, memory);
126 if memory.best().is_none() {
127 self.peers.remove(&to);
128 }
129 }
130 }
131
132 for (peer, memory) in self.peers.iter_mut() {
134 if memory.paths.remove(conn).is_some() {
135 Self::select_best_for(peer, memory);
136 }
137 }
138 self.peers.retain(|_k, v| v.best().is_some());
139 }
140
141 fn action(&self, dest: &PeerId) -> Option<RouteAction> {
142 if self.peer_id.eq(dest) {
143 Some(RouteAction::Local)
144 } else {
145 self.peers.get(dest)?.best().map(RouteAction::Next)
146 }
147 }
148
149 fn next_remote(&self, next: &PeerId) -> Option<(ConnectionId, PathMetric)> {
151 let memory = self.peers.get(next)?;
152 let best = memory.best()?;
153 let metric = memory.best_metric().expect("should have metric");
154 Some((best, metric))
155 }
156
157 fn select_best_for(dest: &PeerId, memory: &mut PeerMemory) {
158 if let Some((new_best, metric)) = memory.select_best() {
159 log::info!(
160 "[RouterTable] to {dest} select new path over {new_best} with rtt {} ms over {} hop(s)",
161 metric.rtt_ms,
162 metric.relay_hops
163 );
164 }
165 }
166
167 fn neighbours(&self) -> Vec<(ConnectionId, PeerId, u16)> {
168 self.directs.iter().map(|(k, (peer, v))| (*k, *peer, v.rtt_ms)).collect()
169 }
170}
171
172impl PathMetric {
173 fn score(&self) -> u16 {
174 self.rtt_ms + self.relay_hops as u16 * 10
175 }
176}
177
178impl AddAssign for PathMetric {
179 fn add_assign(&mut self, rhs: Self) {
180 self.relay_hops += rhs.relay_hops;
181 self.rtt_ms += rhs.rtt_ms;
182 }
183}
184
185impl PeerMemory {
186 fn best(&self) -> Option<ConnectionId> {
187 self.best
188 }
189
190 fn best_metric(&self) -> Option<PathMetric> {
191 self.best.map(|p| *self.paths.get(&p).expect("should have metric with best path"))
192 }
193
194 fn select_best(&mut self) -> Option<(ConnectionId, PathMetric)> {
195 let previous = self.best;
196 self.best = None;
197 let mut iter = self.paths.iter();
198 let (peer, metric) = iter.next()?;
199 let mut best_peer = peer;
200 let mut best_score = metric.score();
201
202 for (peer, metric) in iter {
203 if best_score > metric.score() {
204 best_peer = peer;
205 best_score = metric.score();
206 }
207 }
208
209 self.best = Some(*best_peer);
210 if self.best != previous {
211 let metric = self.best_metric().expect("should have best metric after select success");
212 Some((*best_peer, metric))
213 } else {
214 None
215 }
216 }
217}
218
219#[derive(Debug, Clone)]
220pub struct SharedRouterTable {
221 table: Arc<RwLock<RouterTable>>,
222}
223
224impl SharedRouterTable {
225 pub fn new(address: PeerId) -> Self {
226 Self {
227 table: Arc::new(RwLock::new(RouterTable::new(address))),
228 }
229 }
230
231 pub fn local_id(&self) -> PeerId {
232 self.table.read().local_id()
233 }
234
235 pub fn create_sync(&self, dest: &PeerId) -> RouterTableSync {
236 self.table.read().create_sync(dest)
237 }
238
239 pub fn apply_sync(&self, conn: ConnectionId, sync: RouterTableSync) {
240 self.table.write().apply_sync(conn, sync);
241 }
242
243 pub fn set_direct(&self, conn: ConnectionId, to: PeerId, ttl_ms: u16) {
244 self.table.write().set_direct(conn, to, ttl_ms);
245 }
246
247 pub fn del_direct(&self, conn: &ConnectionId) {
248 self.table.write().del_direct(conn);
249 }
250
251 pub fn action(&self, dest: &PeerId) -> Option<RouteAction> {
252 self.table.read().action(dest)
253 }
254
255 pub fn next_remote(&self, dest: &PeerId) -> Option<(ConnectionId, PathMetric)> {
256 self.table.read().next_remote(dest)
257 }
258
259 pub fn neighbours(&self) -> Vec<(ConnectionId, PeerId, u16)> {
260 self.table.read().neighbours()
261 }
262}
263
#[cfg(test)]
mod tests {
    use crate::{router::RouterTableSync, ConnectionId, PeerId};

    use super::{RouteAction, RouterTable, MAX_HOPS};

    #[test_log::test]
    fn route_local() {
        let table = RouterTable::new(PeerId(0));
        assert_eq!(table.action(&PeerId(0)), Some(RouteAction::Local));
    }

    #[test_log::test]
    fn create_correct_direct_sync() {
        let mut table = RouterTable::new(PeerId(0));

        let peer1 = PeerId(1);
        let conn1 = ConnectionId(1);
        let peer2 = PeerId(2);
        let conn2 = ConnectionId(2);
        let peer3 = PeerId(3);

        table.set_direct(conn1, peer1, 100);
        table.set_direct(conn2, peer2, 200);

        assert_eq!(table.next_remote(&peer1), Some((conn1, (0, 100).into())));
        assert_eq!(table.next_remote(&peer2), Some((conn2, (0, 200).into())));
        assert_eq!(table.next_remote(&peer3), None);
        assert_eq!(table.neighbours(), vec![(conn1, peer1, 100), (conn2, peer2, 200)]);

        assert_eq!(table.create_sync(&peer1), RouterTableSync(vec![(peer2, (0, 200).into())]));
        assert_eq!(table.create_sync(&peer2), RouterTableSync(vec![(peer1, (0, 100).into())]));
    }

    #[test_log::test]
    fn apply_correct_direct_sync() {
        let mut table = RouterTable::new(PeerId(0));

        let peer1 = PeerId(1);
        let conn1 = ConnectionId(1);
        let peer2 = PeerId(2);
        let peer3 = PeerId(3);
        let peer4 = PeerId(4);
        let conn4 = ConnectionId(4);

        table.set_direct(conn1, peer1, 100);
        table.set_direct(conn4, peer4, 400);

        table.apply_sync(conn1, RouterTableSync(vec![(peer2, (0, 200).into())]));

        assert_eq!(table.next_remote(&peer1), Some((conn1, (0, 100).into())));
        assert_eq!(table.next_remote(&peer2), Some((conn1, (1, 300).into())));
        assert_eq!(table.next_remote(&peer3), None);

        assert_eq!(table.create_sync(&peer1), RouterTableSync(vec![(peer2, (1, 300).into()), (peer4, (0, 400).into())]));
        assert_eq!(table.create_sync(&peer4), RouterTableSync(vec![(peer1, (0, 100).into()), (peer2, (1, 300).into())]));
    }

    #[test_log::test]
    fn dont_create_sync_over_max_hops() {
        let mut table = RouterTable::new(PeerId(0));

        let peer1 = PeerId(1);
        let conn1 = ConnectionId(1);
        let peer2 = PeerId(2);
        let peer3 = PeerId(3);
        let conn3 = ConnectionId(3);

        table.set_direct(conn1, peer1, 100);
        table.set_direct(conn3, peer3, 300);

        table.apply_sync(conn1, RouterTableSync(vec![(peer2, (MAX_HOPS, 200).into())]));
        assert_eq!(table.next_remote(&peer2), Some((conn1, (MAX_HOPS + 1, 300).into())));

        assert_eq!(table.create_sync(&peer3), RouterTableSync(vec![(peer1, (0, 100).into())]));
    }

    #[test_log::test]
    fn should_remove_relay_path_after_disconnect() {
        let mut table = RouterTable::new(PeerId(0));

        let peer1 = PeerId(1);
        let conn1 = ConnectionId(1);

        let peer2 = PeerId(2);

        table.set_direct(conn1, peer1, 100);

        table.apply_sync(conn1, RouterTableSync(vec![(peer2, (MAX_HOPS, 200).into())]));
        assert_eq!(table.next_remote(&peer2), Some((conn1, (MAX_HOPS + 1, 300).into())));

        table.del_direct(&conn1);

        assert_eq!(table.next_remote(&peer2), None);
        // The direct peer itself must be gone as well.
        assert_eq!(table.next_remote(&peer1), None);
        assert!(table.neighbours().is_empty());
    }

    #[test_log::test]
    fn withdraw_relay_path_missing_from_sync() {
        let mut table = RouterTable::new(PeerId(0));

        let peer1 = PeerId(1);
        let conn1 = ConnectionId(1);
        let peer2 = PeerId(2);

        table.set_direct(conn1, peer1, 100);
        table.apply_sync(conn1, RouterTableSync(vec![(peer2, (0, 200).into())]));
        assert_eq!(table.next_remote(&peer2), Some((conn1, (1, 300).into())));

        // peer1 stops advertising peer2: the relayed path must disappear.
        table.apply_sync(conn1, RouterTableSync(vec![]));
        assert_eq!(table.next_remote(&peer2), None);
        assert_eq!(table.next_remote(&peer1), Some((conn1, (0, 100).into())));
        assert_eq!(table.create_sync(&PeerId(9)), RouterTableSync(vec![(peer1, (0, 100).into())]));
    }

    #[test_log::test]
    fn switch_to_better_path_and_fall_back_after_disconnect() {
        let mut table = RouterTable::new(PeerId(0));

        let peer1 = PeerId(1);
        let conn1 = ConnectionId(1);
        let peer2 = PeerId(2);
        let conn2 = ConnectionId(2);
        let peer3 = PeerId(3);

        table.set_direct(conn1, peer1, 100);
        table.set_direct(conn2, peer2, 50);

        table.apply_sync(conn1, RouterTableSync(vec![(peer3, (0, 100).into())]));
        assert_eq!(table.next_remote(&peer3), Some((conn1, (1, 200).into())));

        // A cheaper path over conn2 wins.
        table.apply_sync(conn2, RouterTableSync(vec![(peer3, (0, 20).into())]));
        assert_eq!(table.next_remote(&peer3), Some((conn2, (1, 70).into())));
        assert_eq!(table.action(&peer3), Some(RouteAction::Next(conn2)));

        // Losing conn2 falls back to the path over conn1.
        table.del_direct(&conn2);
        assert_eq!(table.next_remote(&peer3), Some((conn1, (1, 200).into())));
        assert_eq!(table.next_remote(&peer2), None);
    }
}
363}