radicle/node/address/
store.rs

1use std::collections::{BTreeMap, BTreeSet};
2use std::net::IpAddr;
3use std::num::TryFromIntError;
4use std::str::FromStr;
5
6use localtime::LocalTime;
7use sqlite as sql;
8use thiserror::Error;
9
10use crate::node;
11use crate::node::address::{AddressType, KnownAddress, Node, Source};
12use crate::node::UserAgent;
13use crate::node::{Address, Alias, AliasError, AliasStore, Database, NodeId, Penalty, Severity};
14use crate::prelude::Timestamp;
15use crate::sql::transaction;
16
/// Errors returned by the address store.
#[derive(Error, Debug)]
pub enum Error {
    /// An Internal error.
    #[error("internal error: {0}")]
    Internal(#[from] sql::Error),
    /// An alias could not be parsed.
    #[error("alias error: {0}")]
    InvalidAlias(#[from] AliasError),
    /// A node ID (public key) was invalid.
    #[error("node id error: {0}")]
    Node(#[from] crypto::PublicKeyError),
    /// An integer did not fit into the type it was converted to.
    #[error("integer conversion error: {0}")]
    TryFromInt(#[from] TryFromIntError),
    /// No rows returned in query result.
    #[error("no rows returned")]
    NoRows,
}
32
/// An entry returned by the store.
///
/// Pairs one known address with the ID, protocol version and penalty of the
/// node it is stored under.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AddressEntry {
    /// Node ID.
    pub node: NodeId,
    /// Node protocol version.
    pub version: u8,
    /// Node penalty.
    pub penalty: Penalty,
    /// Node address.
    pub address: KnownAddress,
}
45
46/// Address store.
47///
48/// Used to store node addresses and metadata.
49pub trait Store {
50    /// Get the information we have about a node.
51    fn get(&self, id: &NodeId) -> Result<Option<Node>, Error>;
52    /// Get the addresses of a node.
53    fn addresses_of(&self, node: &NodeId) -> Result<Vec<KnownAddress>, Error>;
54    /// Insert a node with associated addresses into the store.
55    ///
56    /// Returns `true` if the node or addresses were updated, and `false` otherwise.
57    fn insert(
58        &mut self,
59        node: &NodeId,
60        version: u8,
61        features: node::Features,
62        alias: &Alias,
63        pow: u32,
64        agent: &UserAgent,
65        timestamp: Timestamp,
66        addrs: impl IntoIterator<Item = KnownAddress>,
67    ) -> Result<bool, Error>;
68    /// Remove a node from the store.
69    fn remove(&mut self, id: &NodeId) -> Result<bool, Error>;
70    /// Returns the number of addresses.
71    fn len(&self) -> Result<usize, Error>;
72    /// Return the number of nodes.
73    fn nodes(&self) -> Result<usize, Error>;
74    /// Returns true if there are no addresses.
75    fn is_empty(&self) -> Result<bool, Error> {
76        self.len().map(|l| l == 0)
77    }
78    /// Check if an address is banned. Also returns `true` if the node this address belongs
79    /// to is banned.
80    fn is_addr_banned(&self, addr: &Address) -> Result<bool, Error>;
81    /// Check if an IP is banned.
82    fn is_ip_banned(&self, ip: IpAddr) -> Result<bool, Error>;
83    /// Get the address entries in the store.
84    fn entries(&self) -> Result<Box<dyn Iterator<Item = AddressEntry>>, Error>;
85    /// Mark a node as attempted at a certain time.
86    fn attempted(&self, nid: &NodeId, addr: &Address, time: Timestamp) -> Result<(), Error>;
87    /// Mark a node as successfully connected at a certain time.
88    fn connected(&self, nid: &NodeId, addr: &Address, time: Timestamp) -> Result<(), Error>;
89    /// Record a node IP address and connection time.
90    fn record_ip(&self, nid: &NodeId, ip: IpAddr, time: Timestamp) -> Result<(), Error>;
91    /// Mark a node as disconnected.
92    fn disconnected(
93        &mut self,
94        nid: &NodeId,
95        addr: &Address,
96        severity: Severity,
97    ) -> Result<(), Error>;
98}
99
/// Extension of the address store for alias-based lookups.
pub trait StoreExt {
    /// Iterator over `(node ID, alias)` pairs; each item may individually fail.
    type NodeAlias<'a>: Iterator<Item = Result<(NodeId, Alias), Error>> + 'a
    where
        Self: 'a;

    /// Find the nodes matching `alias`.
    ///
    /// How the alias is matched is up to the implementation.
    fn nodes_by_alias<'a>(&'a self, alias: &Alias) -> Result<Self::NodeAlias<'a>, Error>;
}
107
108impl Store for Database {
109    fn get(&self, node: &NodeId) -> Result<Option<Node>, Error> {
110        let mut stmt = self.db.prepare(
111            "SELECT version, features, alias, pow, penalty, banned, agent, timestamp
112             FROM nodes
113             WHERE id = ?",
114        )?;
115        stmt.bind((1, node))?;
116
117        if let Some(Ok(row)) = stmt.into_iter().next() {
118            let version = row.read::<i64, _>("version").try_into()?;
119            let features = row.read::<node::Features, _>("features");
120            let alias = Alias::from_str(row.read::<&str, _>("alias"))?;
121            let timestamp = row.read::<Timestamp, _>("timestamp");
122            let pow = row.read::<i64, _>("pow") as u32;
123            let agent = row.read::<UserAgent, _>("agent");
124            let penalty = row.read::<i64, _>("penalty").min(u8::MAX as i64);
125            let penalty = Penalty(penalty as u8);
126            let banned = row.read::<i64, _>("banned").is_positive();
127            let addrs = self.addresses_of(node)?;
128
129            Ok(Some(Node {
130                version,
131                features,
132                alias,
133                pow,
134                agent,
135                timestamp,
136                penalty,
137                addrs,
138                banned,
139            }))
140        } else {
141            Ok(None)
142        }
143    }
144
145    fn is_addr_banned(&self, addr: &Address) -> Result<bool, Error> {
146        let mut stmt = self.db.prepare(
147            "SELECT a.banned, n.banned
148             FROM addresses AS a
149             JOIN nodes AS n ON a.node = n.id
150             WHERE value = ?1 AND type = ?2",
151        )?;
152        stmt.bind((1, addr))?;
153        stmt.bind((2, AddressType::from(addr)))?;
154
155        if let Some(row) = stmt.into_iter().next() {
156            let row = row?;
157            let addr_banned = row.read::<i64, _>(0).is_positive();
158            let node_banned = row.read::<i64, _>(1).is_positive();
159
160            Ok(node_banned || addr_banned)
161        } else {
162            Ok(false)
163        }
164    }
165
166    fn is_ip_banned(&self, ip: IpAddr) -> Result<bool, Error> {
167        let mut stmt = self.db.prepare(
168            "SELECT banned
169             FROM ips
170             WHERE ip = ?1 AND banned > 0",
171        )?;
172        stmt.bind((1, ip.to_string().as_str()))?;
173
174        Ok(stmt.into_iter().next().is_some())
175    }
176
177    fn addresses_of(&self, node: &NodeId) -> Result<Vec<KnownAddress>, Error> {
178        let mut addrs = Vec::new();
179        let mut stmt = self.db.prepare(
180            "SELECT type, value, source, last_attempt, last_success, banned FROM addresses WHERE node = ?",
181        )?;
182        stmt.bind((1, node))?;
183
184        for row in stmt.into_iter() {
185            let row = row?;
186            let _typ = row.read::<AddressType, _>("type");
187            let addr = row.read::<Address, _>("value");
188            let source = row.read::<Source, _>("source");
189            let last_attempt = row
190                .read::<Option<i64>, _>("last_attempt")
191                .map(|t| LocalTime::from_millis(t as u128));
192            let last_success = row
193                .read::<Option<i64>, _>("last_success")
194                .map(|t| LocalTime::from_millis(t as u128));
195            let banned = row.read::<i64, _>("banned").is_positive();
196
197            addrs.push(KnownAddress {
198                addr,
199                source,
200                last_success,
201                last_attempt,
202                banned,
203            });
204        }
205        Ok(addrs)
206    }
207
208    fn len(&self) -> Result<usize, Error> {
209        let row = self
210            .db
211            .prepare("SELECT COUNT(*) FROM addresses")?
212            .into_iter()
213            .next()
214            .ok_or(Error::NoRows)??;
215        let count = row.read::<i64, _>(0) as usize;
216
217        Ok(count)
218    }
219
220    fn nodes(&self) -> Result<usize, Error> {
221        let row = self
222            .db
223            .prepare("SELECT COUNT(*) FROM nodes")?
224            .into_iter()
225            .next()
226            .ok_or(Error::NoRows)??;
227        let count = row.read::<i64, _>(0) as usize;
228
229        Ok(count)
230    }
231
232    fn insert(
233        &mut self,
234        node: &NodeId,
235        version: u8,
236        features: node::Features,
237        alias: &Alias,
238        pow: u32,
239        agent: &UserAgent,
240        timestamp: Timestamp,
241        addrs: impl IntoIterator<Item = KnownAddress>,
242    ) -> Result<bool, Error> {
243        transaction(&self.db, move |db| {
244            let mut stmt = db.prepare(
245                "INSERT INTO nodes (id, version, features, alias, pow, agent, timestamp)
246                 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)
247                 ON CONFLICT DO UPDATE
248                 SET version = ?2, features = ?3, alias = ?4, pow = ?5, agent = ?6, timestamp = ?7
249                 WHERE timestamp < ?7",
250            )?;
251
252            stmt.bind((1, node))?;
253            stmt.bind((2, version as i64))?;
254            stmt.bind((3, features))?;
255            stmt.bind((4, alias.as_str()))?;
256            stmt.bind((5, pow as i64))?;
257            stmt.bind((6, agent.as_str()))?;
258            stmt.bind((7, &timestamp))?;
259            stmt.next()?;
260
261            for addr in addrs {
262                let mut stmt = db.prepare(
263                    "INSERT INTO addresses (node, type, value, source, timestamp)
264                     VALUES (?1, ?2, ?3, ?4, ?5)
265                     ON CONFLICT DO UPDATE
266                     SET timestamp = ?5
267                     WHERE timestamp < ?5",
268                )?;
269                stmt.bind((1, node))?;
270                stmt.bind((2, AddressType::from(&addr.addr)))?;
271                stmt.bind((3, &addr.addr))?;
272                stmt.bind((4, addr.source))?;
273                stmt.bind((5, &timestamp))?;
274                stmt.next()?;
275            }
276            Ok::<_, Error>(db.change_count() > 0)
277        })
278    }
279
280    fn remove(&mut self, node: &NodeId) -> Result<bool, Error> {
281        let mut stmt = self.db.prepare("DELETE FROM nodes WHERE id = ?1")?;
282
283        stmt.bind((1, node))?;
284        stmt.next()?;
285
286        Ok(self.db.change_count() > 0)
287    }
288
289    fn entries(&self) -> Result<Box<dyn Iterator<Item = AddressEntry>>, Error> {
290        let mut stmt = self
291            .db
292            .prepare(
293                "SELECT a.node, a.type, a.value, a.source, a.last_success, a.last_attempt, a.banned, n.version, n.penalty
294                 FROM addresses AS a
295                 JOIN nodes AS n ON a.node = n.id
296                 ORDER BY n.penalty ASC, n.id ASC",
297            )?
298            .into_iter();
299        let mut entries = Vec::new();
300
301        while let Some(Ok(row)) = stmt.next() {
302            let node = row.read::<NodeId, _>("node");
303            let _typ = row.read::<AddressType, _>("type");
304            let addr = row.read::<Address, _>("value");
305            let source = row.read::<Source, _>("source");
306            let last_success = row.read::<Option<i64>, _>("last_success");
307            let last_attempt = row.read::<Option<i64>, _>("last_attempt");
308            let last_success = last_success.map(|t| LocalTime::from_millis(t as u128));
309            let last_attempt = last_attempt.map(|t| LocalTime::from_millis(t as u128));
310            let version = row.read::<i64, _>("version").try_into()?;
311            let banned = row.read::<i64, _>("banned").is_positive();
312            let penalty = row.read::<i64, _>("penalty");
313            let penalty = Penalty(penalty as u8); // Clamped at `u8::MAX`.
314
315            entries.push(AddressEntry {
316                node,
317                version,
318                penalty,
319                address: KnownAddress {
320                    addr,
321                    source,
322                    last_success,
323                    last_attempt,
324                    banned,
325                },
326            });
327        }
328        Ok(Box::new(entries.into_iter()))
329    }
330
331    fn attempted(&self, nid: &NodeId, addr: &Address, time: Timestamp) -> Result<(), Error> {
332        let mut stmt = self.db.prepare(
333            "UPDATE `addresses`
334             SET last_attempt = ?1
335             WHERE node = ?2
336             AND type = ?3
337             AND value = ?4",
338        )?;
339
340        stmt.bind((1, &time))?;
341        stmt.bind((2, nid))?;
342        stmt.bind((3, AddressType::from(addr)))?;
343        stmt.bind((4, addr))?;
344        stmt.next()?;
345
346        Ok(())
347    }
348
349    fn connected(&self, nid: &NodeId, addr: &Address, time: Timestamp) -> Result<(), Error> {
350        transaction(&self.db, |db| {
351            let mut stmt = db.prepare(
352                "UPDATE `addresses`
353                 SET last_success = ?1
354                 WHERE node = ?2
355                 AND type = ?3
356                 AND value = ?4",
357            )?;
358
359            stmt.bind((1, &time))?;
360            stmt.bind((2, nid))?;
361            stmt.bind((3, AddressType::from(addr)))?;
362            stmt.bind((4, addr))?;
363            stmt.next()?;
364
365            // Reduce penalty by half on successful connect.
366            let mut stmt = db.prepare("UPDATE `nodes` SET penalty = penalty / 2 WHERE id = ?1")?;
367
368            stmt.bind((1, nid))?;
369            stmt.next()?;
370
371            Ok(())
372        })
373    }
374
375    fn record_ip(&self, nid: &NodeId, ip: IpAddr, time: Timestamp) -> Result<(), Error> {
376        let mut stmt = self.db.prepare(
377            "INSERT INTO ips (ip, node, last_attempt)
378             VALUES (?1, ?2, ?3)
379             ON CONFLICT DO UPDATE
380             SET last_attempt = ?3
381             WHERE last_attempt < ?3",
382        )?;
383        stmt.bind((1, ip.to_string().as_str()))?;
384        stmt.bind((2, nid))?;
385        stmt.bind((3, &time))?;
386        stmt.next()?;
387
388        Ok(())
389    }
390
391    fn disconnected(
392        &mut self,
393        nid: &NodeId,
394        addr: &Address,
395        severity: Severity,
396    ) -> Result<(), Error> {
397        transaction(&self.db, |db| {
398            let mut stmt = self.db.prepare(
399                "UPDATE `nodes`
400                 SET penalty = penalty + ?2
401                 WHERE id = ?1",
402            )?;
403            stmt.bind((1, nid))?;
404            stmt.bind((2, severity as i64))?;
405            stmt.next()?;
406
407            // If the ban threshold is reached, we ban the node and its addresses.
408            let node = self.get(nid)?.ok_or(Error::NoRows)?;
409            if node.penalty.is_ban_threshold_reached() {
410                let mut stmt = db.prepare("UPDATE `nodes` SET banned = 1 WHERE id = ?1")?;
411                stmt.bind((1, nid))?;
412                stmt.next()?;
413
414                let mut stmt = db.prepare("UPDATE `addresses` SET banned = 1 WHERE value = ?1")?;
415                stmt.bind((1, addr))?;
416                stmt.next()?;
417
418                let mut stmt = db.prepare("UPDATE `ips` SET banned = 1 WHERE node = ?1")?;
419                stmt.bind((1, nid))?;
420                stmt.next()?;
421            }
422            Ok(())
423        })
424    }
425}
426
/// Iterator over `(node ID, alias)` pairs read from a database cursor.
pub struct NodeAliasIter<'a> {
    /// Cursor over the underlying query; each row is parsed on demand.
    inner: sql::CursorWithOwnership<'a>,
}
430
431impl NodeAliasIter<'_> {
432    fn parse_row(row: sql::Row) -> Result<(NodeId, Alias), Error> {
433        let nid = row.try_read::<NodeId, _>("id")?;
434        let alias = row.try_read::<&str, _>("alias")?.parse()?;
435        Ok((nid, alias))
436    }
437}
438
439impl Iterator for NodeAliasIter<'_> {
440    type Item = Result<(NodeId, Alias), Error>;
441
442    fn next(&mut self) -> Option<Self::Item> {
443        let row = self.inner.next()?;
444        Some(row.map_err(Error::from).and_then(Self::parse_row))
445    }
446}
447
448impl StoreExt for Database {
449    type NodeAlias<'a>
450        = NodeAliasIter<'a>
451    where
452        Self: 'a;
453
454    fn nodes_by_alias<'a>(&'a self, alias: &Alias) -> Result<Self::NodeAlias<'a>, Error> {
455        let mut stmt = self.db.prepare(
456            "SELECT id, alias
457             FROM nodes
458             WHERE UPPER(alias) LIKE ?",
459        )?;
460        stmt.bind((
461            1,
462            sql::Value::String(format!("%{}%", alias.as_str().to_uppercase())),
463        ))?;
464        Ok(NodeAliasIter {
465            inner: stmt.into_iter(),
466        })
467    }
468}
469
470impl<T> AliasStore for T
471where
472    T: Store + StoreExt,
473{
474    /// Retrieve `alias` of given node.
475    /// Calls `Self::get` under the hood.
476    fn alias(&self, nid: &NodeId) -> Option<Alias> {
477        self.get(nid)
478            .map(|node| node.map(|n| n.alias))
479            .unwrap_or(None)
480    }
481
482    fn reverse_lookup(&self, alias: &Alias) -> BTreeMap<Alias, BTreeSet<NodeId>> {
483        let Ok(iter) = self.nodes_by_alias(alias) else {
484            return BTreeMap::new();
485        };
486        iter.flatten()
487            .fold(BTreeMap::new(), |mut result, (node, alias)| {
488                let nodes = result.entry(alias).or_default();
489                nodes.insert(node);
490                result
491            })
492    }
493}
494
495impl TryFrom<&sql::Value> for Source {
496    type Error = sql::Error;
497
498    fn try_from(value: &sql::Value) -> Result<Self, Self::Error> {
499        let err = sql::Error {
500            code: None,
501            message: Some("sql: invalid source".to_owned()),
502        };
503        match value {
504            sql::Value::String(s) => match s.as_str() {
505                "bootstrap" => Ok(Source::Bootstrap),
506                "peer" => Ok(Source::Peer),
507                "imported" => Ok(Source::Imported),
508                _ => Err(err),
509            },
510            _ => Err(err),
511        }
512    }
513}
514
515impl sql::BindableWithIndex for Source {
516    fn bind<I: sql::ParameterIndex>(self, stmt: &mut sql::Statement<'_>, i: I) -> sql::Result<()> {
517        match self {
518            Self::Bootstrap => "bootstrap".bind(stmt, i),
519            Self::Peer => "peer".bind(stmt, i),
520            Self::Imported => "imported".bind(stmt, i),
521        }
522    }
523}
524
525impl TryFrom<&sql::Value> for AddressType {
526    type Error = sql::Error;
527
528    fn try_from(value: &sql::Value) -> Result<Self, Self::Error> {
529        let err = sql::Error {
530            code: None,
531            message: Some("sql: invalid address type".to_owned()),
532        };
533        match value {
534            sql::Value::String(s) => match s.as_str() {
535                "ipv4" => Ok(AddressType::Ipv4),
536                "ipv6" => Ok(AddressType::Ipv6),
537                "dns" => Ok(AddressType::Dns),
538                "onion" => Ok(AddressType::Onion),
539                _ => Err(err),
540            },
541            _ => Err(err),
542        }
543    }
544}
545
546impl sql::BindableWithIndex for AddressType {
547    fn bind<I: sql::ParameterIndex>(self, stmt: &mut sql::Statement<'_>, i: I) -> sql::Result<()> {
548        match self {
549            Self::Ipv4 => "ipv4".bind(stmt, i),
550            Self::Ipv6 => "ipv6".bind(stmt, i),
551            Self::Dns => "dns".bind(stmt, i),
552            Self::Onion => "onion".bind(stmt, i),
553        }
554    }
555}
556
557#[cfg(test)]
558#[allow(clippy::unwrap_used)]
559mod test {
560    use std::net;
561
562    use super::*;
563    use crate::test::arbitrary;
564    use cyphernet::addr::NetAddr;
565    use localtime::LocalTime;
566
567    #[test]
568    fn test_empty() {
569        let tmp = tempfile::tempdir().unwrap();
570        let path = tmp.path().join("cache");
571        let cache = Database::open(path).unwrap();
572
573        assert!(cache.is_empty().unwrap());
574    }
575
576    #[test]
577    fn test_get_none() {
578        let alice = arbitrary::gen::<NodeId>(1);
579        let cache = Database::memory().unwrap();
580        let result = cache.get(&alice).unwrap();
581
582        assert!(result.is_none());
583    }
584
585    #[test]
586    fn test_remove_nothing() {
587        let alice = arbitrary::gen::<NodeId>(1);
588        let mut cache = Database::memory().unwrap();
589        let removed = cache.remove(&alice).unwrap();
590
591        assert!(!removed);
592    }
593
    /// Re-inserting a node with a newer timestamp replaces its alias.
    #[test]
    fn test_alias() {
        let alice = arbitrary::gen::<NodeId>(1);
        let mut cache = Database::memory().unwrap();
        let features = node::Features::SEED;
        let timestamp = Timestamp::from(LocalTime::now());
        let ua = UserAgent::default();

        cache
            .insert(
                &alice,
                1,
                features,
                &Alias::new("alice"),
                16,
                &ua,
                timestamp,
                [],
            )
            .unwrap();
        let node = cache.get(&alice).unwrap().unwrap();
        assert_eq!(node.alias.as_ref(), "alice");

        // Same node, strictly newer timestamp: the alias is expected to change.
        cache
            .insert(
                &alice,
                1,
                features,
                &Alias::new("bob"),
                16,
                &ua,
                timestamp + 1,
                [],
            )
            .unwrap();
        let node = cache.get(&alice).unwrap().unwrap();
        assert_eq!(node.alias.as_ref(), "bob");
    }
632
    /// A node inserted with one address can be read back with all fields intact.
    #[test]
    fn test_insert_and_get() {
        let alice = arbitrary::gen::<NodeId>(1);
        let mut cache = Database::memory().unwrap();
        let version = 2;
        let features = node::Features::SEED;
        let timestamp = LocalTime::now().into();
        let ua = UserAgent::default();

        let ka = KnownAddress {
            addr: net::SocketAddr::from(([4, 4, 4, 4], 8776)).into(),
            source: Source::Peer,
            last_success: None,
            last_attempt: None,
            banned: false,
        };
        let inserted = cache
            .insert(
                &alice,
                version,
                features,
                &Alias::new("alice"),
                16,
                &ua,
                timestamp,
                [ka.clone()],
            )
            .unwrap();
        assert!(inserted);

        let node = cache.get(&alice).unwrap().unwrap();

        assert_eq!(node.version, version);
        assert_eq!(node.features, features);
        assert_eq!(node.pow, 16);
        assert_eq!(node.timestamp, timestamp);
        assert_eq!(node.alias.as_ref(), "alice");
        assert_eq!(node.addrs, vec![ka]);
    }
672
    /// Inserting the exact same node and address twice changes nothing the
    /// second time and does not duplicate the address.
    #[test]
    fn test_insert_duplicate() {
        let alice = arbitrary::gen::<NodeId>(1);
        let mut cache = Database::memory().unwrap();
        let features = node::Features::SEED;
        let timestamp = LocalTime::now().into();
        let alias = Alias::new("alice");
        let ua = UserAgent::default();

        let ka = KnownAddress {
            addr: net::SocketAddr::from(([4, 4, 4, 4], 8776)).into(),
            source: Source::Peer,
            last_success: None,
            last_attempt: None,
            banned: false,
        };
        let inserted = cache
            .insert(&alice, 1, features, &alias, 0, &ua, timestamp, [ka.clone()])
            .unwrap();
        assert!(inserted);

        // Identical timestamp: neither the node nor the address is updated.
        let inserted = cache
            .insert(&alice, 1, features, &alias, 0, &ua, timestamp, [ka])
            .unwrap();
        assert!(!inserted);

        assert_eq!(cache.len().unwrap(), 1);
    }
701
    /// Node updates are only applied when the timestamp is strictly newer than
    /// the stored one; addresses from earlier inserts are kept.
    #[test]
    fn test_insert_and_update() {
        let alice = arbitrary::gen::<NodeId>(1);
        let mut cache = Database::memory().unwrap();
        let timestamp = LocalTime::now().into();
        let features = node::Features::SEED;
        let ua1 = UserAgent::default();
        let ua2 = UserAgent::default();
        let alias1 = Alias::new("alice");
        let alias2 = Alias::new("~alice~");
        let ka = KnownAddress {
            addr: net::SocketAddr::from(([4, 4, 4, 4], 8776)).into(),
            source: Source::Peer,
            last_success: None,
            last_attempt: None,
            banned: false,
        };

        let updated = cache
            .insert(
                &alice,
                1,
                features,
                &alias1,
                0,
                &ua1,
                timestamp,
                [ka.clone()],
            )
            .unwrap();
        assert!(updated);

        let updated = cache
            .insert(&alice, 1, features, &alias2, 0, &ua1, timestamp, [])
            .unwrap();
        assert!(!updated, "Can't update using the same timestamp");

        let updated = cache
            .insert(&alice, 1, features, &alias2, 0, &ua1, timestamp - 1, [])
            .unwrap();
        assert!(!updated, "Can't update using a smaller timestamp");

        // The rejected updates must have left the original row untouched.
        let node = cache.get(&alice).unwrap().unwrap();
        assert_eq!(node.alias.as_ref(), "alice");
        assert_eq!(node.timestamp, timestamp);
        assert_eq!(node.pow, 0);

        let updated = cache
            .insert(&alice, 1, features, &alias2, 0, &ua2, timestamp + 1, [])
            .unwrap();
        assert!(updated, "Can update with a larger timestamp");

        let updated = cache
            .insert(
                &alice,
                1,
                node::Features::NONE,
                &alias2,
                1,
                &ua2,
                timestamp + 2,
                [],
            )
            .unwrap();
        assert!(updated);

        // The last accepted update wins; the address from the first insert remains.
        let node = cache.get(&alice).unwrap().unwrap();
        assert_eq!(node.features, node::Features::NONE);
        assert_eq!(node.alias.as_ref(), "~alice~");
        assert_eq!(node.timestamp, timestamp + 2);
        assert_eq!(node.pow, 1);
        assert_eq!(node.addrs, vec![ka]);
        assert_eq!(node.agent, ua2);
    }
776
    /// Removing a node also removes the addresses stored under it, while the
    /// same addresses stored under another node are kept.
    #[test]
    fn test_insert_and_remove() {
        let alice = arbitrary::gen::<NodeId>(1);
        let bob = arbitrary::gen::<NodeId>(1);
        let mut cache = Database::memory().unwrap();
        let timestamp = LocalTime::now().into();
        let ua = UserAgent::default();
        let features = node::Features::SEED;
        let alice_alias = Alias::new("alice");
        let bob_alias = Alias::new("bob");

        // Store the same three addresses under both nodes.
        for addr in [
            ([4, 4, 4, 4], 8776),
            ([7, 7, 7, 7], 8776),
            ([9, 9, 9, 9], 8776),
        ] {
            let ka = KnownAddress {
                addr: net::SocketAddr::from(addr).into(),
                source: Source::Peer,
                last_success: None,
                last_attempt: None,
                banned: false,
            };
            cache
                .insert(
                    &alice,
                    1,
                    features,
                    &alice_alias,
                    0,
                    &ua,
                    timestamp,
                    [ka.clone()],
                )
                .unwrap();
            cache
                .insert(&bob, 1, features, &bob_alias, 0, &ua, timestamp, [ka])
                .unwrap();
        }
        assert_eq!(cache.len().unwrap(), 6);

        let removed = cache.remove(&alice).unwrap();
        assert!(removed);
        assert_eq!(cache.len().unwrap(), 3);

        let removed = cache.remove(&bob).unwrap();
        assert!(removed);
        assert_eq!(cache.len().unwrap(), 0);
    }
826
    /// `entries` returns exactly one entry per inserted (node, address) pair.
    #[test]
    fn test_entries() {
        let ids = arbitrary::vec::<NodeId>(16);
        let mut rng = fastrand::Rng::new();
        let mut cache = Database::memory().unwrap();
        let mut expected = Vec::new();
        let timestamp = LocalTime::now().into();
        let ua = UserAgent::default();
        let features = node::Features::SEED;
        let alias = Alias::new("alice");

        // One random IPv4 socket address per node.
        for id in ids {
            let ip = rng.u32(..);
            let addr = net::SocketAddr::from((net::Ipv4Addr::from(ip), rng.u16(..)));
            let ka = KnownAddress {
                addr: addr.into(),
                source: Source::Bootstrap,
                // TODO: Test times as well.
                last_success: None,
                last_attempt: None,
                banned: false,
            };
            expected.push(AddressEntry {
                node: id,
                version: 3,
                penalty: Penalty::default(),
                address: ka.clone(),
            });
            cache
                .insert(&id, 3, features, &alias, 0, &ua, timestamp, [ka])
                .unwrap();
        }

        let mut actual = cache.entries().unwrap().collect::<Vec<_>>();

        // Sort both sides by node so the comparison does not depend on query order.
        actual.sort_by_key(|ae| ae.node);
        expected.sort_by_key(|ae| ae.node);

        assert_eq!(cache.len().unwrap(), actual.len());
        assert_eq!(actual, expected);
    }
868
    /// Disconnects raise a node's penalty according to severity, and a
    /// successful connection halves it.
    #[test]
    fn test_disconnected() {
        let alice = arbitrary::gen::<NodeId>(1);
        let addr = arbitrary::gen::<Address>(1);
        let mut cache = Database::memory().unwrap();
        let features = node::Features::SEED;
        let timestamp = Timestamp::from(LocalTime::now());
        let ua = UserAgent::default();

        cache
            .insert(
                &alice,
                1,
                features,
                &Alias::new("alice"),
                16,
                &ua,
                timestamp,
                [],
            )
            .unwrap();
        let node = cache.get(&alice).unwrap().unwrap();
        assert_eq!(node.penalty, Penalty::default());

        // A low-severity disconnect is expected to leave the penalty unchanged.
        cache.disconnected(&alice, &addr, Severity::Low).unwrap();
        let node = cache.get(&alice).unwrap().unwrap();
        assert_eq!(node.penalty, Penalty::default());

        cache.disconnected(&alice, &addr, Severity::Medium).unwrap();
        let node = cache.get(&alice).unwrap().unwrap();
        assert_eq!(node.penalty, Penalty(1));

        cache.disconnected(&alice, &addr, Severity::High).unwrap();
        let node = cache.get(&alice).unwrap().unwrap();
        assert_eq!(node.penalty, Penalty(9));

        // Connecting halves the penalty (9 -> 4).
        cache.connected(&alice, &addr, timestamp + 1).unwrap();
        let node = cache.get(&alice).unwrap().unwrap();
        assert_eq!(node.penalty, Penalty(4));
    }
909
    /// Repeated high-severity disconnects eventually ban the node, the address
    /// it was disconnected on, and every IP recorded for it.
    #[test]
    fn test_disconnected_ban() {
        let alice = arbitrary::gen::<NodeId>(1);
        let ua = UserAgent::default();
        let ip1: net::Ipv4Addr = [8, 8, 8, 8].into();
        let ip2: net::Ipv4Addr = [9, 9, 9, 9].into();
        let ka1 = arbitrary::gen::<KnownAddress>(1);
        let ka1 = KnownAddress {
            addr: Address::from(NetAddr::new(ip1.into(), 8776)),
            ..ka1
        };
        let ka2 = arbitrary::gen::<KnownAddress>(1);
        let ka2 = KnownAddress {
            addr: Address::from(NetAddr::new(ip2.into(), 8776)),
            ..ka2
        };
        let mut db = Database::memory().unwrap();
        let features = node::Features::SEED;
        let timestamp = Timestamp::from(LocalTime::now());

        db.insert(
            &alice,
            1,
            features,
            &Alias::new("alice"),
            16,
            &ua,
            timestamp,
            [ka1.clone(), ka2.clone()],
        )
        .unwrap();
        db.record_ip(&alice, ip1.into(), timestamp).unwrap();
        db.record_ip(&alice, ip2.into(), timestamp).unwrap();

        let node = db.get(&alice).unwrap().unwrap();
        assert_eq!(node.penalty, Penalty::default());

        // Seven high-severity disconnects stay below the ban threshold.
        for _ in 0..7 {
            db.disconnected(&alice, &ka1.addr, Severity::High).unwrap();
            let node = db.get(&alice).unwrap().unwrap();

            assert!(!node.penalty.is_ban_threshold_reached());
            assert!(!node.banned);
        }

        // The eighth one crosses it.
        db.disconnected(&alice, &ka1.addr, Severity::High).unwrap();
        let node = db.get(&alice).unwrap().unwrap();

        assert!(node.penalty.is_ban_threshold_reached());
        assert!(node.banned);

        // Only the address passed to `disconnected` carries its own ban flag.
        for addr in node.addrs {
            if addr.addr == ka1.addr {
                assert!(addr.banned);
            } else {
                assert!(!addr.banned);
            }
        }
        assert!(db.is_addr_banned(&ka1.addr).unwrap());
        assert!(db.is_addr_banned(&ka2.addr).unwrap()); // Banned because node is banned.
        assert!(db.is_ip_banned(ip1.into()).unwrap());
        assert!(db.is_ip_banned(ip2.into()).unwrap());
    }
973
    /// Populates the database with nodes under a short and a long alias, then
    /// runs the shared reverse-lookup property check against it.
    // NOTE(review): the exact aliases and IDs come from `AliasInput`, which is
    // defined elsewhere — see `node::properties` for what is being asserted.
    #[test]
    fn test_node_aliases() {
        let mut db = Database::memory().unwrap();
        let input = node::properties::AliasInput::new();
        let (short, short_ids) = input.short();
        let (long, long_ids) = input.long();
        let features = node::Features::SEED;
        let agent = UserAgent::default();
        let timestamp = Timestamp::from(LocalTime::now());
        let ka = arbitrary::gen::<KnownAddress>(1);

        for id in short_ids {
            db.insert(id, 1, features, short, 16, &agent, timestamp, [ka.clone()])
                .unwrap();
        }

        for id in long_ids {
            db.insert(id, 1, features, long, 16, &agent, timestamp, [ka.clone()])
                .unwrap();
        }

        node::properties::test_reverse_lookup(&db, input)
    }
997}