use std::{
collections::HashMap,
fmt::{Debug, Formatter},
hash::Hash,
net::IpAddr,
};
use ts_bart::{RouteModification, RoutingTable, RoutingTableExt};
use ts_control::{Node, StableNodeId};
use ts_keys::{DiscoPublicKey, NodePublicKey};
use ts_transport::PeerId;
// Seals `IndexedField`: `Sealed` lives in a private module, so only the key
// types listed here can satisfy the supertrait bound.
mod private {
    use super::*;

    pub trait Sealed {}

    // Peer identity and key material.
    impl Sealed for PeerId {}
    impl Sealed for StableNodeId {}
    impl Sealed for ts_control::NodeId {}
    impl Sealed for NodePublicKey {}
    impl Sealed for DiscoPublicKey {}

    // Names (owned and borrowed).
    impl Sealed for PeerName {}
    impl Sealed for &str {}

    // Addresses and prefixes.
    impl Sealed for IpAddr {}
    impl Sealed for ipnet::IpNet {}
}
/// A key by which a peer can be found in a [`PeerDb`].
///
/// The trait is sealed, so the set of implementing key types is fixed by this
/// module.
pub trait IndexedField: Debug + private::Sealed {
    /// Returns the id of the peer that `self` resolves to in `db`, if any.
    fn lookup(&self, db: &PeerDb) -> Option<PeerId>;
}

/// A peer hostname or FQDN, as used for name-index keys and lookups.
type PeerName = String;

/// A single-valued secondary index: each key maps to exactly one peer.
type Index<T> = HashMap<T, PeerId>;
/// Normalizes a DNS-style name into the form used as a name-index key.
///
/// A single trailing `.` is dropped, so absolute and relative spellings match.
/// ASCII letters are lowercased; non-ASCII characters are left untouched.
fn canon_name(name: &str) -> String {
    let relative = match name.strip_suffix('.') {
        Some(stripped) => stripped,
        None => name,
    };
    relative.to_ascii_lowercase()
}
/// In-memory table of peers keyed by locally assigned [`PeerId`]s.
///
/// Peers can also be found through secondary indices over their keys, ids,
/// names, tailnet addresses and accepted routes.
#[derive(Default, Clone)]
pub struct PeerDb {
    /// Every stored peer, keyed by its local id.
    peers: HashMap<PeerId, Node>,
    /// Secondary indices, maintained by `upsert`, `remove` and `retain`.
    index_state: IndexState,
    /// Raw value of the next `PeerId` handed out to a previously unseen stable id.
    next_id: u32,
}

impl Debug for PeerDb {
    /// Formats only the peer table; the indices are derived from it.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        Debug::fmt(&self.peers, f)
    }
}
/// Secondary indices over the peers stored in a [`PeerDb`].
///
/// In the single-valued indices, a key shared by two peers is reachable for
/// only one of them. The route index keeps a list of peers per prefix.
#[derive(Default, Clone)]
struct IndexState {
    /// Node public key -> peer.
    nk_idx: Index<NodePublicKey>,
    /// Disco public key -> peer; only peers that have a disco key appear here.
    disco_idx: Index<DiscoPublicKey>,
    /// Stable node id -> peer; `upsert` also uses it to reuse an existing id.
    stableid_idx: Index<StableNodeId>,
    /// Control-plane node id -> peer.
    control_idx: Index<ts_control::NodeId>,
    /// Canonicalized hostname and FQDN -> peer.
    name_idx: Index<PeerName>,
    /// Tailnet IPv4/IPv6 addresses (stored as prefixes) -> peer.
    ip_idx: ts_bart::Table<PeerId>,
    /// Accepted route prefix -> peers whose `accepted_routes` contain it.
    route_idx: ts_bart::Table<smallvec::SmallVec<[PeerId; 2]>>,
}
impl PeerDb {
/// Inserts `new`, or replaces the stored peer with the same `stable_id`, and
/// returns the peer's `PeerId`.
///
/// An existing id is reused when `new.stable_id` is already in the stable-id
/// index; otherwise a fresh id is taken from `next_id`. If the stored node is
/// identical to `new`, nothing else happens. Otherwise each secondary index is
/// rewritten only when its field changed. Stale entries are removed only if
/// they still point at this peer, because another peer may have claimed the
/// same key since it was indexed.
pub fn upsert(&mut self, new: &Node) -> PeerId {
// Reuse the id of a previously seen stable id; otherwise allocate a new one.
let id = self
.index_state
.stableid_idx
.get(&new.stable_id)
.copied()
.unwrap_or_else(|| {
let id = self.next_id;
self.next_id += 1;
PeerId(id)
});
let old = self.peers.get(&id);
// Fast path: nothing changed, so no index needs touching.
if old.is_some_and(|x| x == new) {
return id;
}
maybe_update_idx(new, old, |x| &x.node_key, &mut self.index_state.nk_idx, id);
maybe_update_idx(
new,
old,
|x| &x.stable_id,
&mut self.index_state.stableid_idx,
id,
);
maybe_update_idx(new, old, |x| &x.id, &mut self.index_state.control_idx, id);
// Disco keys are optional; only `Some` keys are (un)indexed, and the old
// key is dropped only if it still maps to this peer.
maybe_update(
new,
old,
|x| &x.disco_key,
&mut self.index_state.disco_idx,
|old, idx| {
if let Some(key) = &old.disco_key {
if idx.get(key).is_some_and(|&x| x == id) {
idx.remove(key);
}
}
},
|new, idx| {
if let Some(key) = &new.disco_key {
idx.insert(*key, id);
}
},
);
// Names: the canonicalized hostname and, when available, the canonicalized
// FQDN both map to this peer. Re-indexed whenever hostname or tailnet change.
maybe_update(
new,
old,
|x| (&x.hostname, &x.tailnet),
&mut self.index_state.name_idx,
|old, idx| {
let old_hostname = canon_name(&old.hostname);
if idx.get(&old_hostname).is_some_and(|&x| x == id) {
idx.remove(&old_hostname);
}
if let Some(fqdn) = old.fqdn_opt(false) {
let k = canon_name(&fqdn);
if idx.get(&k).is_some_and(|&x| x == id) {
idx.remove(&k);
}
}
},
|new, idx| {
idx.insert(canon_name(&new.hostname), id);
if let Some(fqdn) = new.fqdn_opt(false) {
idx.insert(canon_name(&fqdn), id);
}
},
);
// Tailnet addresses: the IPv4 and IPv6 addresses are stored as prefixes in
// the IP table; old prefixes are removed only if this peer still owns them.
maybe_update(
new,
old,
|x| &x.tailnet_address,
&mut self.index_state.ip_idx,
|old, idx| {
let ipv4: ipnet::IpNet = old.tailnet_address.ipv4.into();
let ipv6: ipnet::IpNet = old.tailnet_address.ipv6.into();
if idx.lookup_prefix_exact(ipv4).is_some_and(|&x| x == id) {
idx.remove(ipv4);
}
if idx.lookup_prefix_exact(ipv6).is_some_and(|&x| x == id) {
idx.remove(ipv6);
}
},
|new, idx| {
idx.insert(new.tailnet_address.ipv4.into(), id);
idx.insert(new.tailnet_address.ipv6.into(), id);
},
);
// Routes map to a list of peers, so `id` is appended once per accepted
// route (per occurrence) instead of overwriting the entry.
maybe_update(
new,
old,
|x| &x.accepted_routes,
&mut self.index_state,
|old, idx| {
for &route in &old.accepted_routes {
idx.remove_route(route, id);
}
},
|new, idx| {
for &route in &new.accepted_routes {
idx.route_idx.modify(route, |val| {
if let Some(val) = val {
val.push(id);
return RouteModification::Noop;
}
RouteModification::Insert(smallvec::smallvec![id])
});
}
},
);
self.peers.insert(id, new.clone());
id
}
/// Removes the peer that `field` resolves to and drops its index entries.
///
/// Returns the removed id and node, or `None` if `field` matches no stored peer.
pub fn remove(&mut self, field: &dyn IndexedField) -> Option<(PeerId, Node)> {
let id = field.lookup(self)?;
let node = self.peers.remove(&id)?;
self.index_state.remove(id, &node);
Some((id, node))
}
/// Looks up a peer by any indexed field, returning its id and node.
pub fn get(&self, field: &dyn IndexedField) -> Option<(PeerId, &Node)> {
let id = field.lookup(self)?;
let peer = self.peers.get(&id)?;
Some((id, peer))
}
/// Returns the peers stored in the route-index entry that `lookup_prefix`
/// selects for `route`.
///
/// NOTE(review): presumably the longest stored prefix covering `route`;
/// confirm against ts_bart's `lookup_prefix` semantics.
pub fn get_route(&self, route: ipnet::IpNet) -> impl Iterator<Item = (PeerId, &Node)> {
self.index_state
.route_idx
.lookup_prefix(route)
.into_iter()
.flat_map(|x| x.iter())
// Route-index ids are added by `upsert` alongside `peers` and removed
// with them, so a missing peer here is an invariant violation.
.map(|&id| (id, self.peers.get(&id).unwrap()))
}
/// Returns the id of the peer `field` resolves to, without borrowing the node.
pub fn has(&self, field: &dyn IndexedField) -> Option<PeerId> {
field.lookup(self)
}
/// Read-only view of every stored peer, keyed by id.
pub const fn peers(&self) -> &HashMap<PeerId, Node> {
&self.peers
}
/// Keeps only the peers for which `predicate` returns `true`, unindexing
/// every peer that is dropped.
pub fn retain(&mut self, mut predicate: impl FnMut(PeerId, &Node) -> bool) {
self.peers.retain(|&id, node| {
let retain = predicate(id, node);
if !retain {
self.index_state.remove(id, node);
}
retain
});
}
}
impl IndexState {
fn remove(&mut self, id: PeerId, node: &Node) {
self.nk_idx.remove(&node.node_key);
self.stableid_idx.remove(&node.stable_id);
self.control_idx.remove(&node.id);
self.ip_idx.remove(node.tailnet_address.ipv4.into());
self.ip_idx.remove(node.tailnet_address.ipv6.into());
let hostname = canon_name(&node.hostname);
if self.name_idx.get(&hostname).is_some_and(|&x| x == id) {
self.name_idx.remove(&hostname);
}
if let Some(fqdn) = node.fqdn_opt(false) {
self.name_idx.remove(&canon_name(&fqdn));
}
for route in &node.accepted_routes {
self.remove_route(*route, id);
}
if let Some(disco) = &node.disco_key {
self.disco_idx.remove(disco);
}
}
fn remove_route(&mut self, route: ipnet::IpNet, id: PeerId) {
self.route_idx.modify(route, |val| match val {
Some(val) => {
let mut some_matched = false;
val.retain(|&mut x| {
let ids_match = x == id;
if ids_match {
some_matched = true;
}
!ids_match
});
assert!(some_matched);
if val.is_empty() {
RouteModification::Remove
} else {
RouteModification::Noop
}
}
None => RouteModification::Noop,
});
}
#[cfg(test)]
fn is_empty(&self) -> bool {
self.nk_idx.is_empty()
&& self.stableid_idx.is_empty()
&& self.control_idx.is_empty()
&& self.ip_idx.size() == 0
&& self.name_idx.is_empty()
&& self.route_idx.size() == 0
&& self.disco_idx.is_empty()
}
}
/// Re-indexes one field of a peer if it changed between `old` and `new`.
///
/// When `old` is present and `accessor` yields equal values for both nodes,
/// the index is left untouched. Otherwise `remove` runs for the old node (if
/// there is one), and then `insert` runs for the new node.
fn maybe_update<'n, T, Idx>(
    new: &'n Node,
    old: Option<&'n Node>,
    accessor: impl Fn(&'n Node) -> T,
    idx: &mut Idx,
    mut remove: impl FnMut(&'n Node, &mut Idx),
    mut insert: impl FnMut(&'n Node, &mut Idx),
) where
    T: PartialEq + 'n,
{
    if let Some(prev) = old {
        if accessor(prev) == accessor(new) {
            // Field unchanged: existing entries are still correct.
            return;
        }
        remove(prev, idx);
    }
    insert(new, idx);
}
/// [`maybe_update`] specialised to a single-valued [`Index`].
///
/// The old key is removed only while it still maps to `new_id`, so an entry
/// another peer has since claimed is left alone. The new key is then pointed
/// at `new_id`.
fn maybe_update_idx<T>(
    new: &Node,
    old: Option<&Node>,
    accessor: impl Fn(&Node) -> &T,
    idx: &mut Index<T>,
    new_id: PeerId,
) where
    T: Eq + Hash + Clone,
{
    maybe_update(
        new,
        old,
        &accessor,
        idx,
        |stale, map| {
            let key = accessor(stale);
            let owned_by_us = matches!(map.get(key), Some(&owner) if owner == new_id);
            if owned_by_us {
                map.remove(key);
            }
        },
        |fresh, map| {
            map.insert(accessor(fresh).clone(), new_id);
        },
    );
}
impl IndexedField for PeerId {
    /// A `PeerId` resolves to itself for as long as that peer is stored.
    fn lookup(&self, db: &PeerDb) -> Option<PeerId> {
        db.peers.contains_key(self).then_some(*self)
    }
}
impl IndexedField for NodePublicKey {
    /// Resolves through the node-key index.
    fn lookup(&self, db: &PeerDb) -> Option<PeerId> {
        let index = &db.index_state.nk_idx;
        index.get(self).copied()
    }
}
impl IndexedField for DiscoPublicKey {
    /// Resolves through the disco-key index.
    fn lookup(&self, db: &PeerDb) -> Option<PeerId> {
        let index = &db.index_state.disco_idx;
        index.get(self).copied()
    }
}
impl IndexedField for StableNodeId {
    /// Resolves through the stable-id index.
    fn lookup(&self, db: &PeerDb) -> Option<PeerId> {
        let index = &db.index_state.stableid_idx;
        index.get(self).copied()
    }
}
impl IndexedField for ts_control::NodeId {
    /// Resolves through the control-plane node-id index.
    fn lookup(&self, db: &PeerDb) -> Option<PeerId> {
        let index = &db.index_state.control_idx;
        index.get(self).copied()
    }
}
impl IndexedField for PeerName {
fn lookup(&self, db: &PeerDb) -> Option<PeerId> {
db.index_state.name_idx.get(&canon_name(self)).copied()
}
}
impl IndexedField for &str {
    /// Canonicalizes the name the same way it was canonicalized when indexed,
    /// then resolves it through the name index.
    fn lookup(&self, db: &PeerDb) -> Option<PeerId> {
        let key = canon_name(self);
        db.index_state.name_idx.get(&key).copied()
    }
}
impl IndexedField for IpAddr {
    /// Resolves a tailnet address through the IP table's `lookup`.
    fn lookup(&self, db: &PeerDb) -> Option<PeerId> {
        let addr = *self;
        db.index_state.ip_idx.lookup(addr).copied()
    }
}
#[cfg(test)]
mod test {
    use std::{
        collections::{HashMap, HashSet},
        net::{Ipv4Addr, Ipv6Addr, SocketAddr},
        num::NonZeroU32,
    };

    use proptest::{
        collection::{hash_set, vec},
        prelude::any,
        strategy::Strategy,
    };
    use rand::{
        RngExt,
        distr::{Alphanumeric, SampleString},
    };
    use ts_control::TailnetAddress;

    use super::*;

    /// Random alphanumeric string whose length is drawn from `1..max_len`.
    fn rand_string(rng: &mut dyn rand::Rng, max_len: usize) -> String {
        let len = rng.random_range(1..max_len);
        Alphanumeric.sample_string(rng, len)
    }

    /// Random truncated IPv4 or IPv6 route. All randomness (including the
    /// prefix length) comes from `rng`, so the caller controls the source.
    fn rand_route(rng: &mut dyn rand::Rng) -> ipnet::IpNet {
        if rng.random::<bool>() {
            let ip = rand_ipv4(rng);
            ipnet::Ipv4Net::new(ip, rng.random_range(0..=32))
                .unwrap()
                .trunc()
                .into()
        } else {
            let ip = rand_ipv6(rng);
            ipnet::Ipv6Net::new(ip, rng.random_range(0..=128))
                .unwrap()
                .trunc()
                .into()
        }
    }

    fn rand_ipv4(rng: &mut dyn rand::Rng) -> Ipv4Addr {
        Ipv4Addr::from_octets(rng.random::<[u8; 4]>())
    }

    fn rand_ipv6(rng: &mut dyn rand::Rng) -> Ipv6Addr {
        Ipv6Addr::from_segments(rng.random::<[u16; 8]>())
    }

    /// A node with random identity, keys, addresses, routes, names and tags.
    fn rand_node() -> Node {
        let mut rng = rand::rng();
        Node {
            stable_id: StableNodeId(rand_string(&mut rng, 32)),
            tailnet_address: TailnetAddress {
                ipv4: rand_ipv4(&mut rng).into(),
                ipv6: rand_ipv6(&mut rng).into(),
            },
            node_key: rng.random::<[u8; 32]>().into(),
            key_signature: vec![],
            disco_key: rng
                .random::<bool>()
                .then_some(rng.random::<[u8; 32]>().into()),
            machine_key: rng
                .random::<bool>()
                .then_some(rng.random::<[u8; 32]>().into()),
            id: rng.random(),
            accepted_routes: (0..rng.random_range(0..32))
                .map(|_| rand_route(&mut rng))
                .collect(),
            hostname: rand_string(&mut rng, 32),
            user_id: rng.random(),
            tailnet: rng.random::<bool>().then_some(rand_string(&mut rng, 32)),
            node_key_expiry: None,
            online: None,
            last_seen: None,
            underlay_addresses: vec![],
            derp_region: rng
                .random::<bool>()
                .then_some(ts_derp::RegionId(rng.random())),
            tags: (0..rng.random_range(0..8))
                .map(|_| rand_string(&mut rng, 32))
                .collect(),
            cap: Default::default(),
            cap_map: Default::default(),
            peerapi_port: None,
            peerapi_dns_proxy: false,
            is_wireguard_only: false,
            exit_node_dns_resolvers: vec![],
            peer_relay: false,
            service_vips: Default::default(),
        }
    }

    /// Asserts that every indexed field of `node` resolves to `id`. It also
    /// checks that each of its routes can be found through `get_route`.
    fn validate_indices(db: &PeerDb, node: &Node, id: PeerId) {
        let ipv4 = IpAddr::from(node.tailnet_address.ipv4.addr());
        let ipv6 = IpAddr::from(node.tailnet_address.ipv6.addr());
        let fqdn = node.fqdn_opt(false);
        let mut keys: Vec<&dyn IndexedField> =
            vec![&id, &node.node_key, &node.stable_id, &node.id, &ipv4, &ipv6];
        if let Some(disco) = &node.disco_key {
            keys.push(disco);
        }
        if let Some(fqdn) = &fqdn {
            keys.push(fqdn);
        }
        for k in keys {
            let lookup_id = k.lookup(db).unwrap();
            assert_eq!(lookup_id, id, "wrong id for key {k:?}");
            let (lookup_id, lookup_node) = db.get(k).unwrap();
            assert_eq!(lookup_id, id, "wrong id for key {k:?}");
            assert_eq!(lookup_node, node, "wrong node for key {k:?}");
        }
        // Hostnames may collide across peers, so only require that it resolves.
        node.hostname.lookup(db).unwrap();
        for &route in &node.accepted_routes {
            let routes = db.get_route(route).collect::<Vec<_>>();
            assert!(!routes.is_empty());
            for (found_id, found_node) in routes {
                if found_id == id {
                    assert_eq!(found_node, node);
                    break;
                }
                let has_subset = found_node
                    .accepted_routes
                    .iter()
                    .any(|found_route| route.contains(found_route));
                assert!(has_subset);
            }
        }
    }

    /// Asserts that every route of `node` returns `(id, node)` from `get_route`.
    fn assert_has_routes_exact(db: &PeerDb, node: &Node, id: PeerId) {
        for &route in &node.accepted_routes {
            let match_exists = db
                .get_route(route)
                .any(|(found_id, found_node)| found_id == id && found_node == node);
            assert!(match_exists);
        }
    }

    #[test]
    fn test_indices() {
        let mut db = PeerDb::default();
        let node = rand_node();
        let id = db.upsert(&node);
        validate_indices(&db, &node, id);
        assert_has_routes_exact(&db, &node, id);
    }

    #[test]
    fn test_names() {
        let mut db = PeerDb::default();
        let node1 = Node {
            hostname: "test".to_string(),
            tailnet: Some("ts.net".to_string()),
            ..rand_node()
        };
        let node2 = Node {
            hostname: "test".to_string(),
            tailnet: Some("ts2.net".to_string()),
            ..rand_node()
        };
        let node3 = Node {
            hostname: "test".to_string(),
            tailnet: None,
            ..rand_node()
        };
        let id1 = db.upsert(&node1);
        let id2 = db.upsert(&node2);
        let id3 = db.upsert(&node3);
        let nodes = [(id1, &node1), (id2, &node2), (id3, &node3)];
        for (id, node) in &nodes {
            validate_indices(&db, node, *id);
        }
        // The bare hostname is shared, so it resolves to exactly one of them.
        let (id, node) = db.get(&"test").unwrap();
        assert!(nodes.iter().any(|(x, _node)| *x == id));
        for &(x, curnode) in &nodes {
            if x == id {
                assert_eq!(node, curnode);
            } else {
                assert_ne!(node, curnode);
            }
        }
        // FQDNs are distinct and must resolve unambiguously.
        let (id, node) = db.get(&"test.ts.net").unwrap();
        assert_eq!(id, id1);
        assert_eq!(node, &node1);
        let (id, node) = db.get(&"test.ts2.net").unwrap();
        assert_eq!(id, id2);
        assert_eq!(node, &node2);
    }

    #[test]
    fn test_name_lookup_is_canonicalized() {
        let mut db = PeerDb::default();
        let node = Node {
            hostname: "MixedCase".to_string(),
            tailnet: Some("Tail-Scale.ts.net".to_string()),
            ..rand_node()
        };
        let id = db.upsert(&node);
        assert_eq!(db.get(&"mixedcase").unwrap().0, id);
        assert_eq!(db.get(&"MIXEDCASE").unwrap().0, id);
        assert_eq!(db.get(&"mixedcase.tail-scale.ts.net").unwrap().0, id);
        assert_eq!(db.get(&"MixedCase.Tail-Scale.TS.NET").unwrap().0, id);
        assert_eq!(db.get(&"mixedcase.tail-scale.ts.net.").unwrap().0, id);
        db.remove(&id);
        assert!(db.get(&"mixedcase").is_none());
        assert!(db.get(&"mixedcase.tail-scale.ts.net").is_none());
        assert!(db.index_state.is_empty());
    }

    #[test]
    fn disco_key_reassigned_across_peers_no_panic() {
        let mut db = PeerDb::default();
        let disco: DiscoPublicKey = [7u8; 32].into();
        let node_a = Node {
            disco_key: Some(disco),
            ..rand_node()
        };
        let id_a = db.upsert(&node_a);
        let node_b = Node {
            disco_key: Some(disco),
            ..rand_node()
        };
        let id_b = db.upsert(&node_b);
        assert_ne!(id_a, id_b);
        let node_a2 = Node {
            disco_key: None,
            ..node_a.clone()
        };
        let id_a2 = db.upsert(&node_a2);
        assert_eq!(id_a, id_a2);
        assert_eq!(disco.lookup(&db), Some(id_b));
    }

    #[test]
    fn ip_reassigned_across_peers_no_panic() {
        let mut db = PeerDb::default();
        let shared = TailnetAddress {
            ipv4: Ipv4Addr::new(100, 64, 0, 1).into(),
            ipv6: Ipv6Addr::new(0xfd7a, 0, 0, 0, 0, 0, 0, 1).into(),
        };
        let node_a = Node {
            tailnet_address: shared.clone(),
            ..rand_node()
        };
        let id_a = db.upsert(&node_a);
        let node_b = Node {
            tailnet_address: shared.clone(),
            ..rand_node()
        };
        let id_b = db.upsert(&node_b);
        assert_ne!(id_a, id_b);
        let node_a2 = Node {
            tailnet_address: TailnetAddress {
                ipv4: Ipv4Addr::new(100, 64, 0, 2).into(),
                ipv6: Ipv6Addr::new(0xfd7a, 0, 0, 0, 0, 0, 0, 2).into(),
            },
            ..node_a.clone()
        };
        let id_a2 = db.upsert(&node_a2);
        assert_eq!(id_a, id_a2);
        assert_eq!(
            IpAddr::from(Ipv4Addr::new(100, 64, 0, 1)).lookup(&db),
            Some(id_b)
        );
        assert_eq!(
            IpAddr::from(Ipv4Addr::new(100, 64, 0, 2)).lookup(&db),
            Some(id_a)
        );
    }

    #[test]
    fn node_key_or_stableid_churn_no_panic() {
        let mut db = PeerDb::default();
        let key: NodePublicKey = [9u8; 32].into();
        let node_a = Node {
            node_key: key,
            ..rand_node()
        };
        let id_a = db.upsert(&node_a);
        let node_b = Node {
            node_key: key,
            ..rand_node()
        };
        let id_b = db.upsert(&node_b);
        assert_ne!(id_a, id_b);
        let node_a2 = Node {
            node_key: [10u8; 32].into(),
            ..node_a.clone()
        };
        let id_a2 = db.upsert(&node_a2);
        assert_eq!(id_a, id_a2);
        assert_eq!(key.lookup(&db), Some(id_b));
        assert_eq!(NodePublicKey::from([10u8; 32]).lookup(&db), Some(id_a));
    }

    // Truncated IPv4 prefixes over the full 0..=32 prefix-length range.
    proptest::prop_compose! {
        fn ipv4net()(
            addr: Ipv4Addr,
            pfx in 0u8..=32,
        ) -> ipnet::Ipv4Net {
            ipnet::Ipv4Net::new(addr, pfx).unwrap().trunc()
        }
    }

    // Truncated IPv6 prefixes over the full 0..=128 prefix-length range, so
    // long IPv6 prefixes are exercised too.
    proptest::prop_compose! {
        fn ipv6net()(
            addr: Ipv6Addr,
            pfx in 0u8..=128,
        ) -> ipnet::Ipv6Net {
            ipnet::Ipv6Net::new(addr, pfx).unwrap().trunc()
        }
    }

    fn ipnet() -> impl Strategy<Value = ipnet::IpNet> {
        proptest::prop_oneof![
            ipv4net().prop_map(ipnet::IpNet::from),
            ipv6net().prop_map(ipnet::IpNet::from)
        ]
    }

    // A single DNS label: a lowercase letter followed by lowercase alphanumerics.
    proptest::prop_compose! {
        fn domain_segment()(
            seg in "[a-z][a-z0-9]*"
        ) -> String {
            seg
        }
    }

    // Up to `max_count - 1` labels joined with dots (possibly empty).
    proptest::prop_compose! {
        fn domain(max_count: usize)(
            segs in proptest::collection::vec(domain_segment(), 0..max_count)
        ) -> String {
            segs.join(".")
        }
    }

    type Key = [u8; 32];

    // `n` nodes with pairwise-distinct ids, stable ids, node keys, addresses
    // and hostnames. Each node also accepts its own tailnet addresses as routes.
    proptest::prop_compose! {
        fn nodes(n: usize)(
            id in hash_set(any::<i64>(), n),
            stable_id in hash_set(".+", n),
            tags in vec(hash_set(".+", 0..32), n),
            accepted_routes in vec(hash_set(ipnet(), 0..32), n),
            node_key in hash_set(any::<Key>(), n),
            machine_key in vec(any::<Option<Key>>(), n),
            disco_key in vec(any::<Option<Key>>(), n),
            ipv4 in hash_set(any::<Ipv4Addr>(), n),
            ipv6 in hash_set(any::<Ipv6Addr>(), n),
            name in hash_set(domain_segment(), n),
            tailnet in vec(domain(5), n),
            has_tailnet in vec(any::<bool>(), n),
            derp_region in vec(any::<Option<NonZeroU32>>(), n),
            underlay_addrs in vec(any::<HashSet<SocketAddr>>(), n),
        ) -> Vec<Node> {
            itertools::izip![
                id,
                stable_id,
                tags,
                accepted_routes,
                node_key,
                machine_key,
                disco_key,
                ipv4,
                ipv6,
                name,
                tailnet,
                has_tailnet,
                derp_region,
                underlay_addrs,
            ].map(|(
                id,
                stable_id,
                tags,
                mut accepted_routes,
                node_key,
                machine_key,
                disco_key,
                ipv4,
                ipv6,
                name,
                tailnet,
                has_tailnet,
                derp_region,
                underlay_addrs,
            )| {
                accepted_routes.insert(ipnet::Ipv4Net::from(ipv4).into());
                accepted_routes.insert(ipnet::Ipv6Net::from(ipv6).into());
                Node {
                    id,
                    stable_id: StableNodeId(stable_id),
                    hostname: name,
                    user_id: 0,
                    tailnet: has_tailnet.then_some(tailnet),
                    node_key: node_key.into(),
                    key_signature: vec![],
                    disco_key: disco_key.map(Into::into),
                    machine_key: machine_key.map(Into::into),
                    node_key_expiry: None,
                    online: None,
                    last_seen: None,
                    tailnet_address: TailnetAddress {
                        ipv4: ipv4.into(),
                        ipv6: ipv6.into(),
                    },
                    tags: tags.into_iter().collect(),
                    derp_region: derp_region.map(ts_derp::RegionId),
                    accepted_routes: accepted_routes.into_iter().collect(),
                    underlay_addresses: underlay_addrs.into_iter().collect(),
                    cap: Default::default(),
                    cap_map: Default::default(),
                    peerapi_port: None,
                    peerapi_dns_proxy: false,
                    is_wireguard_only: false,
                    exit_node_dns_resolvers: vec![],
                    peer_relay: false,
                    service_vips: Default::default(),
                }
            })
            .collect()
        }
    }

    proptest::proptest! {
        #[test]
        fn prop_one_node_indices(mut nodes in nodes(1)) {
            let node = nodes.pop().unwrap();
            let mut db = PeerDb::default();
            let id = db.upsert(&node);
            validate_indices(&db, &node, id);
            assert_has_routes_exact(&db, &node, id);
        }

        #[test]
        fn prop_many_nodes_indexed(nodes in nodes(16)) {
            let mut db = PeerDb::default();
            let mut nodes_by_id = HashMap::new();
            for node in &nodes {
                let id = db.upsert(node);
                nodes_by_id.insert(id, node.clone());
            }
            for (id, node) in &nodes_by_id {
                validate_indices(&db, node, *id);
            }
        }

        #[test]
        fn prop_remove(nodes in nodes(16)) {
            let mut db = PeerDb::default();
            let mut ids = vec![];
            for node in &nodes {
                ids.push((db.upsert(node), node));
            }
            for (id, node) in ids {
                let (removed_id, removed_node) = db.remove(&id).unwrap();
                proptest::prop_assert_eq!(removed_id, id);
                proptest::prop_assert_eq!(&removed_node, node);
            }
            proptest::prop_assert!(db.peers.is_empty());
            proptest::prop_assert!(db.index_state.is_empty());
        }
    }
}