kadmium/tcp/
router.rs

1use std::{
2    cmp::Ordering,
3    collections::{HashMap, HashSet},
4    net::SocketAddr,
5};
6
7use rand::{seq::IteratorRandom, thread_rng, Fill};
8use time::OffsetDateTime;
9
10use crate::core::{
11    id::Id,
12    message::{Chunk, FindKNodes, KNodes, Message, Ping, Pong, Response},
13    routing_table::RoutingTable,
14    traits::ProcessData,
15};
16
/// Whether the router currently considers a peer connected.
///
/// Peers are recorded as `Disconnected` when first inserted and flip to `Connected` only through
/// `TcpRouter::set_connected`; `TcpRouter::set_disconnected` flips them back.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConnState {
    /// The peer has a registered connection address and sits in its bucket.
    Connected,
    /// The peer is known (it has metadata) but has no registered connection.
    Disconnected,
}
22
23#[derive(Debug, Clone, Copy)]
24pub struct TcpMeta {
25    pub(crate) listening_addr: SocketAddr,
26    pub(crate) conn_addr: Option<SocketAddr>,
27    pub(crate) conn_state: ConnState,
28    pub(crate) last_seen: Option<OffsetDateTime>,
29}
30
31impl TcpMeta {
32    fn new(
33        listening_addr: SocketAddr,
34        conn_addr: Option<SocketAddr>,
35        conn_state: ConnState,
36        last_seen: Option<OffsetDateTime>,
37    ) -> Self {
38        Self {
39            listening_addr,
40            conn_addr,
41            conn_state,
42            last_seen,
43        }
44    }
45}
46
47type ConnAddr = SocketAddr;
48
49/// The core router implementation.
50#[derive(Debug, Clone)]
51pub struct TcpRouter {
52    /// The routing table backing this router.
53    rt: RoutingTable<ConnAddr, TcpMeta>,
54}
55
56impl Default for TcpRouter {
57    fn default() -> Self {
58        let mut rng = thread_rng();
59        let mut bytes = [0u8; Id::BYTES];
60        debug_assert!(bytes.try_fill(&mut rng).is_ok());
61
62        Self {
63            rt: RoutingTable {
64                local_id: Id::new(bytes),
65                max_bucket_size: 20,
66                k: 20,
67                buckets: HashMap::new(),
68                peer_list: HashMap::new(),
69                id_list: HashMap::new(),
70            },
71        }
72    }
73}
74
75impl TcpRouter {
76    /// Creates a new router.
77    pub fn new(local_id: Id, max_bucket_size: u8, k: u8) -> Self {
78        Self {
79            rt: RoutingTable {
80                local_id,
81                max_bucket_size,
82                k,
83                buckets: HashMap::new(),
84                peer_list: HashMap::new(),
85                id_list: HashMap::new(),
86            },
87        }
88    }
89
90    #[cfg(feature = "sync")]
91    pub(crate) fn routing_table(&self) -> &RoutingTable<ConnAddr, TcpMeta> {
92        &self.rt
93    }
94
95    /// Returns this router's local identifier.
96    pub fn local_id(&self) -> Id {
97        self.rt.local_id
98    }
99
100    /// Returns the identifier corresponding to the address, if it exists.
101    pub fn peer_id(&self, addr: SocketAddr) -> Option<Id> {
102        self.rt.id_list.get(&addr).copied()
103    }
104
105    /// Returns the peer's metadata if it exists.
106    pub fn peer_meta(&self, id: &Id) -> Option<&TcpMeta> {
107        self.rt.peer_list.get(id)
108    }
109
110    /// Returns `true` if the record exists already or was inserted, `false` if an attempt was made to
111    /// insert our local identifier.
112    pub fn insert(&mut self, id: Id, listening_addr: SocketAddr) -> bool {
113        // Buckets should only contain connected peers. The other structures should track
114        // connection state.
115
116        // An insert can happen in two instances:
117        //
118        // 1. the peer initiated the connection (should only be inserted into the bucket if there
119        //    is space, this requires a later call to set_connected).
120        // 2. the peer was included in a list from another peer (should be inserted as
121        //    disconnected unless it is already in the list and is connected).
122        //
123        // Insert all peers as disconnected initially. The caller can then check if the
124        // bucket has space and if so initiates a connection in case 1, or accepts the connection
125        // in case 2.
126        //
127        // If a peer exists already, we update the address information without reseting the
128        // connecton state. The node wrapping this implementation should make sure to call
129        // `set_disconnected` if a connection is closed or dropped.
130
131        if id == self.rt.local_id {
132            return false;
133        }
134
135        self.rt
136            .peer_list
137            .entry(id)
138            .and_modify(|meta| {
139                meta.listening_addr = listening_addr;
140            })
141            .or_insert_with(|| TcpMeta::new(listening_addr, None, ConnState::Disconnected, None));
142
143        self.rt.id_list.insert(listening_addr, id);
144
145        true
146    }
147
148    /// Returns whether or not there is space in the bucket corresponding to the identifier and the
149    /// appropriate bucket index if there is.
150    pub fn can_connect(&mut self, id: Id) -> (bool, Option<u32>) {
151        let i = match self.local_id().log2_distance(&id) {
152            Some(i) => i,
153            None => return (false, None),
154        };
155
156        // TODO: check if identifier is already in use?
157
158        let bucket = self.rt.buckets.entry(i).or_insert_with(HashSet::new);
159        match bucket.len().cmp(&self.rt.max_bucket_size.into()) {
160            Ordering::Less => {
161                // Bucket still has space. Signal the value could be inserted into the bucket (once
162                // the connection is succesful).
163                (true, Some(i))
164            }
165            Ordering::Equal => {
166                // Bucket is full. Signal the value can't currently be inserted into the bucket.
167                (false, None)
168            }
169            Ordering::Greater => {
170                // Bucket is over capacity, this should never happen.
171                unreachable!()
172            }
173        }
174    }
175
176    /// Sets the peer as connected on the router, returning `false` if there is no room to
177    /// connect the peer.
178    pub fn set_connected(&mut self, id: Id, conn_addr: SocketAddr) -> bool {
179        match self.can_connect(id) {
180            (true, Some(i)) => {
181                if let (Some(peer_meta), Some(bucket)) =
182                    (self.rt.peer_list.get_mut(&id), self.rt.buckets.get_mut(&i))
183                {
184                    // If the bucket insert returns `false`, it means the id is already in the bucket and the
185                    // peer is connected.
186                    let _res = bucket.insert(id);
187                    debug_assert!(_res);
188
189                    self.rt.id_list.insert(conn_addr, id);
190                    peer_meta.conn_addr = Some(conn_addr);
191                    peer_meta.conn_state = ConnState::Connected;
192                    peer_meta.last_seen = Some(OffsetDateTime::now_utc());
193                }
194
195                true
196            }
197
198            _ => false,
199        }
200    }
201
202    /// Removes an identifier from the buckets, sets the peer to disconnected and returns `true` on
203    /// success, `false` otherwise.
204    pub fn set_disconnected(&mut self, conn_addr: SocketAddr) -> bool {
205        let id = match self.peer_id(conn_addr) {
206            Some(id) => id,
207            None => return false,
208        };
209
210        let i = self
211            .local_id()
212            .log2_distance(&id)
213            .expect("self can't have an identifier in the peer list");
214
215        if let (Some(peer_meta), Some(bucket)) =
216            (self.rt.peer_list.get_mut(&id), self.rt.buckets.get_mut(&i))
217        {
218            let bucket_res = bucket.remove(&id);
219            debug_assert!(bucket_res);
220
221            // Remove the entry from the identifier list as the addr is likely to change when a
222            // peer reconnects later (also this means we only have one collection tracking
223            // disconnected peers for simplicity).
224            let id_list_res = self
225                .rt
226                .id_list
227                .remove(&peer_meta.conn_addr.expect("conn_addr must be present"));
228            debug_assert!(id_list_res.is_some());
229
230            peer_meta.conn_addr = None;
231            peer_meta.conn_state = ConnState::Disconnected;
232
233            return bucket_res && id_list_res.is_some();
234        }
235
236        false
237    }
238
239    /// Selects the broadcast peers for a particular height, returns `None` if the broadcast
240    /// shouldn't continue any further.
241    pub fn select_broadcast_peers(&self, height: u32) -> Option<Vec<(u32, SocketAddr)>> {
242        let mut rng = thread_rng();
243
244        // Don't broadcast any further.
245        if height == 0 {
246            return None;
247        }
248
249        let mut selected_peers = vec![];
250        for h in 0..height {
251            if let Some(bucket) = self.rt.buckets.get(&h) {
252                // Choose one peer at random per bucket.
253                if let Some(id) = bucket.iter().choose(&mut rng) {
254                    // The value should exist as the peer is in the bucket.
255                    let peer_meta = self.rt.peer_list.get(id);
256                    debug_assert!(peer_meta.is_some());
257                    debug_assert_eq!(peer_meta.unwrap().conn_state, ConnState::Connected);
258                    // Return the connection address, not the listening address as we need to
259                    // broadcast through the connection, not to the listener.
260                    debug_assert!(peer_meta.unwrap().conn_addr.is_some());
261                    let addr = peer_meta.unwrap().conn_addr.unwrap();
262
263                    selected_peers.push((h, addr))
264                }
265            }
266        }
267
268        Some(selected_peers)
269    }
270
271    /// Returns the K closest nodes to the identifier.
272    fn find_k_closest(&self, id: &Id, k: usize) -> Vec<(Id, SocketAddr)> {
273        // There is a total order over the id-space, though we take the log2 of the XOR distance,
274        // and so peers within a bucket are considered at the same distance. We use an unstable
275        // sort as we don't care if items at the same distance are reordered (and it is usually
276        // faster).
277        let mut ids: Vec<_> = self
278            .rt
279            .peer_list
280            .iter()
281            .map(|(&candidate_id, &candidate_meta)| (candidate_id, candidate_meta.listening_addr))
282            .collect();
283        // TODO: bench and consider sort_by_cached_key.
284        ids.sort_unstable_by_key(|(candidate_id, _)| candidate_id.log2_distance(id));
285        ids.truncate(k);
286
287        ids
288    }
289
290    // MESSAGE PROCESSING
291
292    /// Processes a peer's message. If it is a query, an appropriate response is returned to
293    /// be sent.
294    pub fn process_message<S: Clone, T: ProcessData<S>>(
295        &mut self,
296        state: S,
297        message: Message,
298        conn_addr: SocketAddr,
299    ) -> Option<Response> {
300        let id = match self.peer_id(conn_addr) {
301            Some(id) => id,
302            None => return None,
303        };
304
305        // Update the peer's last seen timestamp.
306        if let Some(peer_meta) = self.rt.peer_list.get_mut(&id) {
307            peer_meta.last_seen = Some(OffsetDateTime::now_utc())
308        }
309
310        match message {
311            Message::Ping(ping) => {
312                let pong = self.process_ping(ping);
313                Some(Response::Unicast(Message::Pong(pong)))
314            }
315            Message::Pong(pong) => {
316                self.process_pong(pong);
317                None
318            }
319            Message::FindKNodes(find_k_nodes) => {
320                let k_nodes = self.process_find_k_nodes(find_k_nodes);
321                Some(Response::Unicast(Message::KNodes(k_nodes)))
322            }
323            Message::KNodes(k_nodes) => {
324                self.process_k_nodes(k_nodes);
325                None
326            }
327            Message::Chunk(chunk) => {
328                if let Some(broadcast) = self.process_chunk::<S, T>(state, chunk) {
329                    let broadcast = broadcast
330                        .into_iter()
331                        .map(|(addr, message)| (addr, Message::Chunk(message)))
332                        .collect();
333
334                    Some(Response::Broadcast(broadcast))
335                } else {
336                    None
337                }
338            }
339        }
340    }
341
342    fn process_ping(&mut self, ping: Ping) -> Pong {
343        // Prepare a response, send back the same nonce so the original sender can identify the
344        // request the response corresponds to.
345        Pong {
346            nonce: ping.nonce,
347            id: self.local_id(),
348        }
349    }
350
351    fn process_pong(&mut self, _pong: Pong) {
352        // TODO: how should latency factor into the broadcast logic? Perhaps keep a table with the
353        // message nonces for latency calculation?
354    }
355
356    fn process_find_k_nodes(&self, find_k_nodes: FindKNodes) -> KNodes {
357        let k_closest_nodes = self.find_k_closest(&find_k_nodes.id, self.rt.k as usize);
358
359        KNodes {
360            nonce: find_k_nodes.nonce,
361            nodes: k_closest_nodes,
362        }
363    }
364
365    fn process_k_nodes(&mut self, k_nodes: KNodes) {
366        // Save the new peer information.
367        for (id, listening_addr) in k_nodes.nodes {
368            self.insert(id, listening_addr);
369        }
370
371        // TODO: work out who to connect with to continue the recursive search. Should this be
372        // continual or only when bootstrapping the network?
373    }
374
375    fn process_chunk<S: Clone, T: ProcessData<S>>(
376        &self,
377        state: S,
378        chunk: Chunk,
379    ) -> Option<Vec<(SocketAddr, Chunk)>> {
380        // Cheap as the backing storage is shared amongst instances.
381        let data = chunk.data.clone();
382        let data_as_t: T = chunk.data.into();
383        let is_kosher = data_as_t.verify_data(state.clone());
384
385        // This is where the buckets come in handy. When a node processes a chunk message, it
386        // selects peers in buckets ]h, 0] and propagates the CHUNK message. If h = 0, no
387        // propagation occurs.
388        if !is_kosher {
389            return None;
390        }
391
392        data_as_t.process_data(state);
393
394        // TODO: return the wrapped data as well as the peers to propagate to.
395        self.select_broadcast_peers(chunk.height).map(|v| {
396            v.iter()
397                .map(|(height, addr)| {
398                    (
399                        *addr,
400                        Chunk {
401                            // TODO: work out if this is a bad idea.
402                            nonce: chunk.nonce,
403                            height: *height,
404                            data: data.clone(),
405                        },
406                    )
407                })
408                .collect()
409        })
410    }
411}
412
#[cfg(test)]
mod tests {
    use super::*;

    // Produces a local address from the supplied port.
    // TODO: consider testing with random ports and IDs.
    fn localhost_with_port(port: u16) -> SocketAddr {
        SocketAddr::from(([127, 0, 0, 1], port))
    }

    // Generates `n` peers with identifiers `1..=n`, each listening on the port matching its id.
    fn generate_peers(n: u16) -> Vec<(Id, SocketAddr)> {
        (1..=n)
            .map(|i| (Id::from_u16(i), localhost_with_port(i)))
            .collect()
    }

    #[test]
    fn peer_id_not_present() {
        let router = TcpRouter::new(Id::from_u16(0), 1, 20);
        assert!(router.peer_id(localhost_with_port(0)).is_none());
    }

    #[test]
    fn peer_id_is_present() {
        let mut router = TcpRouter::new(Id::from_u16(0), 1, 20);

        let id = Id::from_u16(1);
        let addr = localhost_with_port(1);

        assert!(router.insert(id, addr));
        assert_eq!(router.peer_id(addr), Some(id));
    }

    #[test]
    fn insert() {
        let mut router = TcpRouter::new(Id::from_u16(0), 1, 20);
        // ... 0001 -> bucket i = 0
        assert!(router.insert(Id::from_u16(1), localhost_with_port(1)));
        // ... 0010 -> bucket i = 1
        assert!(router.insert(Id::from_u16(2), localhost_with_port(2)));
        // ... 0011 -> bucket i = 1
        assert!(router.insert(Id::from_u16(3), localhost_with_port(3)));
    }

    #[test]
    fn insert_defaults() {
        let mut router = TcpRouter::new(Id::from_u16(0), 1, 20);

        let id = Id::from_u16(1);
        let addr = localhost_with_port(1);

        assert!(router.insert(id, addr));

        let meta = router.peer_meta(&id).unwrap();
        assert_eq!(meta.listening_addr, addr);
        assert_eq!(meta.conn_addr, None);
        assert_eq!(meta.conn_state, ConnState::Disconnected);
        assert_eq!(meta.last_seen, None);
    }

    #[test]
    fn insert_self() {
        let mut router = TcpRouter::new(Id::from_u16(0), 1, 20);
        // Attempt to insert our local id.
        assert!(!router.insert(router.local_id(), localhost_with_port(0)));
    }

    #[test]
    fn insert_duplicate() {
        let mut router = TcpRouter::new(Id::from_u16(0), 1, 20);
        // The double insert will still return true.
        assert!(router.insert(Id::from_u16(1), localhost_with_port(1)));
        assert!(router.insert(Id::from_u16(1), localhost_with_port(1)));
    }

    #[test]
    fn insert_duplicate_updates_listening_addr() {
        let mut router = TcpRouter::new(Id::from_u16(0), 1, 20);

        let id = Id::from_u16(1);
        let addr = localhost_with_port(1);
        assert!(router.insert(id, addr));
        assert_eq!(router.peer_meta(&id).unwrap().listening_addr, addr);

        // Inserting the same identifier with a different address should result in the new address
        // getting stored.
        let addr = localhost_with_port(2);
        assert!(router.insert(id, addr));
        assert_eq!(router.peer_meta(&id).unwrap().listening_addr, addr);
    }

    #[test]
    fn set_connected() {
        // Set the max bucket size to a low value so we can easily test when it's full.
        let mut router = TcpRouter::new(Id::from_u16(0), 1, 20);

        // ... 0001 -> bucket i = 0
        let addr = localhost_with_port(1);
        let id = Id::from_u16(1);
        router.insert(id, addr);
        assert!(router.set_connected(id, addr));

        // ... 0010 -> bucket i = 1
        let addr = localhost_with_port(2);
        let id = Id::from_u16(2);
        router.insert(id, addr);
        assert!(router.set_connected(id, addr));

        // ... 0011 -> bucket i = 1 (already full)
        let addr = localhost_with_port(3);
        let id = Id::from_u16(3);
        router.insert(id, addr);
        assert!(!router.set_connected(id, addr));
    }

    #[test]
    fn set_connected_self() {
        let mut router = TcpRouter::new(Id::from_u16(0), 1, 20);
        assert!(!router.set_connected(router.local_id(), localhost_with_port(0)));
    }

    #[test]
    fn set_disconnected() {
        let mut router = TcpRouter::new(Id::from_u16(0), 1, 20);
        let id = Id::from_u16(1);
        let addr = localhost_with_port(1);
        assert!(router.insert(id, addr));
        assert!(router.set_connected(id, addr));
        assert!(router.set_disconnected(addr));
    }

    #[test]
    fn set_disconnected_non_existant() {
        let mut router = TcpRouter::new(Id::from_u16(0), 1, 20);
        assert!(!router.set_disconnected(localhost_with_port(0)));
    }

    #[test]
    fn find_k_closest() {
        let mut router = TcpRouter::new(Id::from_u16(0), 5, 20);

        // Generate 5 identifiers and addresses.
        let peers = generate_peers(5);

        for &(id, addr) in &peers {
            assert!(router.insert(id, addr));
            assert!(router.set_connected(id, addr));
        }

        let k = 3;
        let k_closest = router.find_k_closest(&router.local_id(), k);

        assert_eq!(k_closest.len(), 3);
        assert!(k_closest.contains(&peers[0]));
        assert!(k_closest.contains(&peers[1]));
        assert!(k_closest.contains(&peers[2]));
    }

    #[test]
    fn find_k_closest_empty() {
        let router = TcpRouter::new(Id::from_u16(0), 5, 20);
        let k = 3;
        let k_closest = router.find_k_closest(&router.local_id(), k);
        assert_eq!(k_closest.len(), 0);
    }

    #[test]
    fn select_broadcast_peers() {
        let mut router = TcpRouter::new(Id::from_u16(0), 5, 20);

        // Generate 5 identifiers and addresses:
        // ... 0001 -> bucket i = 0
        // ... 0010, ... 0011 -> bucket i = 1
        // ... 0100, ... 0101 -> bucket i = 2
        let peers = generate_peers(5);

        for &(id, addr) in &peers {
            // Conn address is listening address (all peers received the connections).
            assert!(router.insert(id, addr));
            assert!(router.set_connected(id, addr));
        }

        // If the height is 0, we are the last node in the recursion, don't broadcast.
        assert!(router.select_broadcast_peers(0).is_none());

        // A height of 1 only covers bucket 0, which holds exactly one peer: the selection must be
        // that peer's connection address, tagged with height 0.
        let selected_peers = router.select_broadcast_peers(1).unwrap();
        assert_eq!(selected_peers, vec![(0, peers[0].1)]);

        // A height of 3 covers buckets 0, 1 and 2: one peer is picked from each, in bucket order.
        let selected_peers = router.select_broadcast_peers(3).unwrap();
        let heights: Vec<u32> = selected_peers.iter().map(|&(h, _)| h).collect();
        assert_eq!(heights, vec![0, 1, 2]);
        assert_eq!(selected_peers[0].1, peers[0].1);
        assert!([peers[1].1, peers[2].1].contains(&selected_peers[1].1));
        assert!([peers[3].1, peers[4].1].contains(&selected_peers[2].1));

        // Heights beyond the populated buckets don't select any extra peers.
        assert_eq!(router.select_broadcast_peers(10).unwrap().len(), 3);
    }
}