1use std::{
2 cmp::Ordering,
3 collections::{HashMap, HashSet},
4 net::SocketAddr,
5};
6
7use rand::{seq::IteratorRandom, thread_rng, Fill};
8use time::OffsetDateTime;
9
10use crate::core::{
11 id::Id,
12 message::{Chunk, FindKNodes, KNodes, Message, Ping, Pong, Response},
13 routing_table::RoutingTable,
14 traits::ProcessData,
15};
16
/// Connection status of a peer known to the router.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConnState {
    /// The peer has an active connection and occupies a slot in its routing-table bucket.
    Connected,
    /// The peer is known to the router but has no active connection.
    Disconnected,
}
22
23#[derive(Debug, Clone, Copy)]
24pub struct TcpMeta {
25 pub(crate) listening_addr: SocketAddr,
26 pub(crate) conn_addr: Option<SocketAddr>,
27 pub(crate) conn_state: ConnState,
28 pub(crate) last_seen: Option<OffsetDateTime>,
29}
30
31impl TcpMeta {
32 fn new(
33 listening_addr: SocketAddr,
34 conn_addr: Option<SocketAddr>,
35 conn_state: ConnState,
36 last_seen: Option<OffsetDateTime>,
37 ) -> Self {
38 Self {
39 listening_addr,
40 conn_addr,
41 conn_state,
42 last_seen,
43 }
44 }
45}
46
/// Address type the TCP routing table is parameterized over.
type ConnAddr = std::net::SocketAddr;
48
49#[derive(Debug, Clone)]
51pub struct TcpRouter {
52 rt: RoutingTable<ConnAddr, TcpMeta>,
54}
55
56impl Default for TcpRouter {
57 fn default() -> Self {
58 let mut rng = thread_rng();
59 let mut bytes = [0u8; Id::BYTES];
60 debug_assert!(bytes.try_fill(&mut rng).is_ok());
61
62 Self {
63 rt: RoutingTable {
64 local_id: Id::new(bytes),
65 max_bucket_size: 20,
66 k: 20,
67 buckets: HashMap::new(),
68 peer_list: HashMap::new(),
69 id_list: HashMap::new(),
70 },
71 }
72 }
73}
74
75impl TcpRouter {
76 pub fn new(local_id: Id, max_bucket_size: u8, k: u8) -> Self {
78 Self {
79 rt: RoutingTable {
80 local_id,
81 max_bucket_size,
82 k,
83 buckets: HashMap::new(),
84 peer_list: HashMap::new(),
85 id_list: HashMap::new(),
86 },
87 }
88 }
89
90 #[cfg(feature = "sync")]
91 pub(crate) fn routing_table(&self) -> &RoutingTable<ConnAddr, TcpMeta> {
92 &self.rt
93 }
94
95 pub fn local_id(&self) -> Id {
97 self.rt.local_id
98 }
99
100 pub fn peer_id(&self, addr: SocketAddr) -> Option<Id> {
102 self.rt.id_list.get(&addr).copied()
103 }
104
105 pub fn peer_meta(&self, id: &Id) -> Option<&TcpMeta> {
107 self.rt.peer_list.get(id)
108 }
109
110 pub fn insert(&mut self, id: Id, listening_addr: SocketAddr) -> bool {
113 if id == self.rt.local_id {
132 return false;
133 }
134
135 self.rt
136 .peer_list
137 .entry(id)
138 .and_modify(|meta| {
139 meta.listening_addr = listening_addr;
140 })
141 .or_insert_with(|| TcpMeta::new(listening_addr, None, ConnState::Disconnected, None));
142
143 self.rt.id_list.insert(listening_addr, id);
144
145 true
146 }
147
148 pub fn can_connect(&mut self, id: Id) -> (bool, Option<u32>) {
151 let i = match self.local_id().log2_distance(&id) {
152 Some(i) => i,
153 None => return (false, None),
154 };
155
156 let bucket = self.rt.buckets.entry(i).or_insert_with(HashSet::new);
159 match bucket.len().cmp(&self.rt.max_bucket_size.into()) {
160 Ordering::Less => {
161 (true, Some(i))
164 }
165 Ordering::Equal => {
166 (false, None)
168 }
169 Ordering::Greater => {
170 unreachable!()
172 }
173 }
174 }
175
176 pub fn set_connected(&mut self, id: Id, conn_addr: SocketAddr) -> bool {
179 match self.can_connect(id) {
180 (true, Some(i)) => {
181 if let (Some(peer_meta), Some(bucket)) =
182 (self.rt.peer_list.get_mut(&id), self.rt.buckets.get_mut(&i))
183 {
184 let _res = bucket.insert(id);
187 debug_assert!(_res);
188
189 self.rt.id_list.insert(conn_addr, id);
190 peer_meta.conn_addr = Some(conn_addr);
191 peer_meta.conn_state = ConnState::Connected;
192 peer_meta.last_seen = Some(OffsetDateTime::now_utc());
193 }
194
195 true
196 }
197
198 _ => false,
199 }
200 }
201
202 pub fn set_disconnected(&mut self, conn_addr: SocketAddr) -> bool {
205 let id = match self.peer_id(conn_addr) {
206 Some(id) => id,
207 None => return false,
208 };
209
210 let i = self
211 .local_id()
212 .log2_distance(&id)
213 .expect("self can't have an identifier in the peer list");
214
215 if let (Some(peer_meta), Some(bucket)) =
216 (self.rt.peer_list.get_mut(&id), self.rt.buckets.get_mut(&i))
217 {
218 let bucket_res = bucket.remove(&id);
219 debug_assert!(bucket_res);
220
221 let id_list_res = self
225 .rt
226 .id_list
227 .remove(&peer_meta.conn_addr.expect("conn_addr must be present"));
228 debug_assert!(id_list_res.is_some());
229
230 peer_meta.conn_addr = None;
231 peer_meta.conn_state = ConnState::Disconnected;
232
233 return bucket_res && id_list_res.is_some();
234 }
235
236 false
237 }
238
239 pub fn select_broadcast_peers(&self, height: u32) -> Option<Vec<(u32, SocketAddr)>> {
242 let mut rng = thread_rng();
243
244 if height == 0 {
246 return None;
247 }
248
249 let mut selected_peers = vec![];
250 for h in 0..height {
251 if let Some(bucket) = self.rt.buckets.get(&h) {
252 if let Some(id) = bucket.iter().choose(&mut rng) {
254 let peer_meta = self.rt.peer_list.get(id);
256 debug_assert!(peer_meta.is_some());
257 debug_assert_eq!(peer_meta.unwrap().conn_state, ConnState::Connected);
258 debug_assert!(peer_meta.unwrap().conn_addr.is_some());
261 let addr = peer_meta.unwrap().conn_addr.unwrap();
262
263 selected_peers.push((h, addr))
264 }
265 }
266 }
267
268 Some(selected_peers)
269 }
270
271 fn find_k_closest(&self, id: &Id, k: usize) -> Vec<(Id, SocketAddr)> {
273 let mut ids: Vec<_> = self
278 .rt
279 .peer_list
280 .iter()
281 .map(|(&candidate_id, &candidate_meta)| (candidate_id, candidate_meta.listening_addr))
282 .collect();
283 ids.sort_unstable_by_key(|(candidate_id, _)| candidate_id.log2_distance(id));
285 ids.truncate(k);
286
287 ids
288 }
289
290 pub fn process_message<S: Clone, T: ProcessData<S>>(
295 &mut self,
296 state: S,
297 message: Message,
298 conn_addr: SocketAddr,
299 ) -> Option<Response> {
300 let id = match self.peer_id(conn_addr) {
301 Some(id) => id,
302 None => return None,
303 };
304
305 if let Some(peer_meta) = self.rt.peer_list.get_mut(&id) {
307 peer_meta.last_seen = Some(OffsetDateTime::now_utc())
308 }
309
310 match message {
311 Message::Ping(ping) => {
312 let pong = self.process_ping(ping);
313 Some(Response::Unicast(Message::Pong(pong)))
314 }
315 Message::Pong(pong) => {
316 self.process_pong(pong);
317 None
318 }
319 Message::FindKNodes(find_k_nodes) => {
320 let k_nodes = self.process_find_k_nodes(find_k_nodes);
321 Some(Response::Unicast(Message::KNodes(k_nodes)))
322 }
323 Message::KNodes(k_nodes) => {
324 self.process_k_nodes(k_nodes);
325 None
326 }
327 Message::Chunk(chunk) => {
328 if let Some(broadcast) = self.process_chunk::<S, T>(state, chunk) {
329 let broadcast = broadcast
330 .into_iter()
331 .map(|(addr, message)| (addr, Message::Chunk(message)))
332 .collect();
333
334 Some(Response::Broadcast(broadcast))
335 } else {
336 None
337 }
338 }
339 }
340 }
341
342 fn process_ping(&mut self, ping: Ping) -> Pong {
343 Pong {
346 nonce: ping.nonce,
347 id: self.local_id(),
348 }
349 }
350
351 fn process_pong(&mut self, _pong: Pong) {
352 }
355
356 fn process_find_k_nodes(&self, find_k_nodes: FindKNodes) -> KNodes {
357 let k_closest_nodes = self.find_k_closest(&find_k_nodes.id, self.rt.k as usize);
358
359 KNodes {
360 nonce: find_k_nodes.nonce,
361 nodes: k_closest_nodes,
362 }
363 }
364
365 fn process_k_nodes(&mut self, k_nodes: KNodes) {
366 for (id, listening_addr) in k_nodes.nodes {
368 self.insert(id, listening_addr);
369 }
370
371 }
374
375 fn process_chunk<S: Clone, T: ProcessData<S>>(
376 &self,
377 state: S,
378 chunk: Chunk,
379 ) -> Option<Vec<(SocketAddr, Chunk)>> {
380 let data = chunk.data.clone();
382 let data_as_t: T = chunk.data.into();
383 let is_kosher = data_as_t.verify_data(state.clone());
384
385 if !is_kosher {
389 return None;
390 }
391
392 data_as_t.process_data(state);
393
394 self.select_broadcast_peers(chunk.height).map(|v| {
396 v.iter()
397 .map(|(height, addr)| {
398 (
399 *addr,
400 Chunk {
401 nonce: chunk.nonce,
403 height: *height,
404 data: data.clone(),
405 },
406 )
407 })
408 .collect()
409 })
410 }
411}
412
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the IPv4 loopback address with the given port.
    fn localhost_with_port(port: u16) -> SocketAddr {
        SocketAddr::from(([127, 0, 0, 1], port))
    }

    /// Returns the peers `(Id::from_u16(i), 127.0.0.1:i)` for `i` in `1..=n`.
    fn numbered_peers(n: u16) -> Vec<(Id, SocketAddr)> {
        (1..=n)
            .map(|i| (Id::from_u16(i), localhost_with_port(i)))
            .collect()
    }

    /// Returns a router with local id `Id::from_u16(0)`, the given bucket size and `k = 20`.
    fn router_with_bucket_size(max_bucket_size: u8) -> TcpRouter {
        TcpRouter::new(Id::from_u16(0), max_bucket_size, 20)
    }

    #[test]
    fn peer_id_not_present() {
        let router = router_with_bucket_size(1);
        assert!(router.peer_id(localhost_with_port(0)).is_none());
    }

    #[test]
    fn peer_id_is_present() {
        let mut router = router_with_bucket_size(1);

        let id = Id::from_u16(1);
        let addr = localhost_with_port(1);

        assert!(router.insert(id, addr));
        assert_eq!(router.peer_id(addr), Some(id));
    }

    #[test]
    fn insert() {
        let mut router = router_with_bucket_size(1);
        for (id, addr) in numbered_peers(3) {
            assert!(router.insert(id, addr));
        }
    }

    #[test]
    fn insert_defaults() {
        let mut router = router_with_bucket_size(1);

        let id = Id::from_u16(1);
        let addr = localhost_with_port(1);

        assert!(router.insert(id, addr));

        let meta = router.peer_meta(&id).unwrap();
        assert_eq!(meta.listening_addr, addr);
        assert_eq!(meta.conn_addr, None);
        assert_eq!(meta.conn_state, ConnState::Disconnected);
        assert_eq!(meta.last_seen, None);
    }

    #[test]
    fn insert_self() {
        let mut router = router_with_bucket_size(1);
        assert!(!router.insert(router.local_id(), localhost_with_port(0)));
    }

    #[test]
    fn insert_duplicate() {
        let mut router = router_with_bucket_size(1);
        assert!(router.insert(Id::from_u16(1), localhost_with_port(1)));
        assert!(router.insert(Id::from_u16(1), localhost_with_port(1)));
    }

    #[test]
    fn insert_duplicate_updates_listening_addr() {
        let mut router = router_with_bucket_size(1);
        let id = Id::from_u16(1);

        let addr = localhost_with_port(1);
        assert!(router.insert(id, addr));
        assert_eq!(router.peer_meta(&id).unwrap().listening_addr, addr);

        let addr = localhost_with_port(2);
        assert!(router.insert(id, addr));
        assert_eq!(router.peer_meta(&id).unwrap().listening_addr, addr);
    }

    #[test]
    fn set_connected() {
        let mut router = router_with_bucket_size(1);
        let peers = numbered_peers(3);

        let (id, addr) = peers[0];
        router.insert(id, addr);
        assert!(router.set_connected(id, addr));

        let meta = router.peer_meta(&id).unwrap();
        assert_eq!(meta.conn_addr, Some(addr));
        assert_eq!(meta.conn_state, ConnState::Connected);
        assert!(meta.last_seen.is_some());

        let (id, addr) = peers[1];
        router.insert(id, addr);
        assert!(router.set_connected(id, addr));

        // With a bucket size of 1 the third peer shares an already-full bucket and is rejected.
        let (id, addr) = peers[2];
        router.insert(id, addr);
        assert!(!router.set_connected(id, addr));
    }

    #[test]
    fn set_connected_self() {
        let mut router = router_with_bucket_size(1);
        assert!(!router.set_connected(router.local_id(), localhost_with_port(0)));
    }

    #[test]
    fn set_disconnected() {
        let mut router = router_with_bucket_size(1);
        let id = Id::from_u16(1);
        let addr = localhost_with_port(1);
        assert!(router.insert(id, addr));
        assert!(router.set_connected(id, addr));
        assert!(router.set_disconnected(addr));

        let meta = router.peer_meta(&id).unwrap();
        assert_eq!(meta.conn_addr, None);
        assert_eq!(meta.conn_state, ConnState::Disconnected);
    }

    #[test]
    fn set_disconnected_non_existant() {
        let mut router = router_with_bucket_size(1);
        assert!(!router.set_disconnected(localhost_with_port(0)));
    }

    #[test]
    fn can_connect_self() {
        let mut router = router_with_bucket_size(1);
        assert_eq!(router.can_connect(router.local_id()), (false, None));
    }

    #[test]
    fn can_connect_full_bucket() {
        let mut router = router_with_bucket_size(1);
        let peers = numbered_peers(3);
        for &(id, addr) in &peers[..2] {
            assert!(router.insert(id, addr));
            assert!(router.set_connected(id, addr));
        }

        // The third peer shares a bucket with one of the first two (see `set_connected`).
        assert_eq!(router.can_connect(peers[2].0), (false, None));
    }

    #[test]
    fn find_k_closest() {
        let mut router = router_with_bucket_size(5);
        let peers = numbered_peers(5);
        for &(id, addr) in &peers {
            assert!(router.insert(id, addr));
            assert!(router.set_connected(id, addr));
        }

        let k = 3;
        let k_closest = router.find_k_closest(&router.local_id(), k);

        assert_eq!(k_closest.len(), k);
        for peer in &peers[..k] {
            assert!(k_closest.contains(peer));
        }
    }

    #[test]
    fn find_k_closest_empty() {
        let router = router_with_bucket_size(5);
        let k_closest = router.find_k_closest(&router.local_id(), 3);
        assert!(k_closest.is_empty());
    }

    #[test]
    fn select_broadcast_peers() {
        let mut router = router_with_bucket_size(5);
        for (id, addr) in numbered_peers(5) {
            assert!(router.insert(id, addr));
            assert!(router.set_connected(id, addr));
        }

        // A height of zero selects nothing at all.
        assert!(router.select_broadcast_peers(0).is_none());

        // A height of one only considers bucket 0.
        let selected_peers = router.select_broadcast_peers(1).unwrap();
        assert_eq!(selected_peers.len(), 1);
        assert_eq!(selected_peers[0].0, 0);
    }
}