use std::collections::HashMap;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use std::sync::Arc;
use iroh::EndpointId;
use iroh::endpoint::Connection;
use smol_str::SmolStr;
use crate::audit::AuditLog;
/// Concurrent hash map used throughout this module, hashed with `ahash` instead of
/// dashmap's default hasher.
pub type FastDashMap<K, V> = dashmap::DashMap<K, V, ahash::RandomState>;
/// Table of known peers, indexed by both the peer's IPv4 and IPv6 address, with
/// one connection per network the peer is reachable on.
///
/// Cloning is cheap and yields a handle to the same data: every field is behind an `Arc`.
#[derive(Clone)]
pub struct PeerTable {
    /// Peer entries keyed by IPv4 address.
    v4: Arc<FastDashMap<Ipv4Addr, PeerEntry>>,
    /// Peer entries keyed by IPv6 address; the mutating methods update it alongside `v4`.
    v6: Arc<FastDashMap<Ipv6Addr, PeerEntry>>,
    /// When set, connect/disconnect events are reported here.
    audit: Option<Arc<AuditLog>>,
}
/// State kept for one peer address: the peer's endpoint identity and its
/// connections, one per network.
pub struct PeerEntry {
    /// Identity of the remote endpoint; `PeerTable::add` overwrites it on every call.
    pub endpoint_id: EndpointId,
    /// Connections to this peer keyed by network name.
    pub conns: HashMap<SmolStr, Connection>,
}
/// Result of a route lookup: the chosen connection, the peer's identity, and the
/// name of the network that connection belongs to.
pub struct PeerRoute {
    pub conn: Connection,
    pub endpoint_id: EndpointId,
    pub network: SmolStr,
}
impl PeerEntry {
    /// Chooses which of this peer's connections to use.
    ///
    /// If the peer is reachable over several networks, the network whose name sorts
    /// first wins, so repeated lookups consistently pick the same connection.
    /// Returns `None` when the entry holds no connections.
    fn route(&self) -> Option<PeerRoute> {
        self.conns
            .iter()
            .min_by_key(|(name, _)| *name)
            .map(|(network, conn)| PeerRoute {
                conn: conn.clone(),
                endpoint_id: self.endpoint_id,
                network: network.clone(),
            })
    }
}
impl Default for PeerTable {
fn default() -> Self {
Self::new()
}
}
impl PeerTable {
pub fn new() -> Self {
Self {
v4: Arc::new(FastDashMap::default()),
v6: Arc::new(FastDashMap::default()),
audit: None,
}
}
pub fn with_audit(audit: Arc<AuditLog>) -> Self {
Self {
v4: Arc::new(FastDashMap::default()),
v6: Arc::new(FastDashMap::default()),
audit: Some(audit),
}
}
pub fn add(
&self,
ip: Ipv4Addr,
ipv6: Ipv6Addr,
conn: Connection,
endpoint_id: EndpointId,
network: &str,
) {
let net = SmolStr::new(network);
let newly_connected;
{
let mut e = self.v4.entry(ip).or_insert_with(|| PeerEntry {
endpoint_id,
conns: HashMap::new(),
});
e.endpoint_id = endpoint_id;
newly_connected = e.conns.insert(net.clone(), conn.clone()).is_none();
}
{
let mut e = self.v6.entry(ipv6).or_insert_with(|| PeerEntry {
endpoint_id,
conns: HashMap::new(),
});
e.endpoint_id = endpoint_id;
e.conns.insert(net, conn);
}
if newly_connected && let Some(audit) = &self.audit {
audit.log_connect(ip, &endpoint_id.to_string());
}
}
pub fn lookup_v4(&self, ip: &Ipv4Addr) -> Option<PeerRoute> {
self.v4.get(ip).and_then(|e| e.route())
}
pub fn lookup_v6(&self, ip: &Ipv6Addr) -> Option<PeerRoute> {
self.v6.get(ip).and_then(|e| e.route())
}
pub fn identity_and_networks(&self, ip: IpAddr) -> Option<(EndpointId, Vec<SmolStr>)> {
match ip {
IpAddr::V4(v4) => self
.v4
.get(&v4)
.map(|e| (e.endpoint_id, e.conns.keys().cloned().collect())),
IpAddr::V6(v6) => self
.v6
.get(&v6)
.map(|e| (e.endpoint_id, e.conns.keys().cloned().collect())),
}
}
pub fn remove(&self, ip: &Ipv4Addr, ipv6: &Ipv6Addr) {
let removed = self.v4.remove(ip);
self.v6.remove(ipv6);
if let (Some((_, entry)), Some(audit)) = (removed, &self.audit) {
audit.log_disconnect(*ip, &entry.endpoint_id.to_string());
}
}
pub fn remove_peer_from_network(&self, ip: &Ipv4Addr, ipv6: &Ipv6Addr, network: &str) {
let mut dropped = None;
if let Some(mut e) = self.v4.get_mut(ip)
&& e.conns.remove(network).is_some()
{
dropped = Some(e.endpoint_id);
}
self.v4.remove_if(ip, |_, e| e.conns.is_empty());
if let Some(mut e) = self.v6.get_mut(ipv6) {
e.conns.remove(network);
}
self.v6.remove_if(ipv6, |_, e| e.conns.is_empty());
if let (Some(endpoint_id), Some(audit)) = (dropped, &self.audit) {
audit.log_disconnect(*ip, &endpoint_id.to_string());
}
}
pub fn all_connections(&self) -> Vec<(Ipv4Addr, Connection)> {
self.v4
.iter()
.filter_map(|e| e.route().map(|r| (*e.key(), r.conn)))
.collect()
}
pub fn remove_by_network(&self, network: &str) -> Vec<Ipv4Addr> {
let mut removed = Vec::new();
self.v4.retain(|ip, e| {
e.conns.remove(network);
if e.conns.is_empty() {
removed.push(*ip);
false
} else {
true
}
});
self.v6.retain(|_ip, e| {
e.conns.remove(network);
!e.conns.is_empty()
});
removed
}
pub fn peers_for_network(&self, network: &str) -> Vec<(EndpointId, Ipv4Addr)> {
self.v4
.iter()
.filter(|e| e.conns.contains_key(network))
.map(|e| (e.endpoint_id, *e.key()))
.collect()
}
pub fn peers_for_network_with_conn(
&self,
network: &str,
) -> Vec<(EndpointId, Ipv4Addr, Connection)> {
self.v4
.iter()
.filter_map(|e| {
e.conns
.get(network)
.map(|c| (e.endpoint_id, *e.key(), c.clone()))
})
.collect()
}
#[cfg(test)]
pub fn all_peer_ids(&self) -> Vec<(Ipv4Addr, EndpointId)> {
self.v4.iter().map(|e| (*e.key(), e.endpoint_id)).collect()
}
}
/// Shared map from device (transport) keys to the user identity each belongs to.
///
/// Cloning is cheap and yields a handle to the same underlying map.
#[derive(Clone)]
pub struct DeviceUserMap {
    /// device key -> user identity
    inner: Arc<FastDashMap<EndpointId, EndpointId>>,
}
impl Default for DeviceUserMap {
fn default() -> Self {
Self::new()
}
}
impl DeviceUserMap {
    /// Creates an empty mapping.
    pub fn new() -> Self {
        let inner = Arc::new(FastDashMap::default());
        Self { inner }
    }

    /// Records that `device_key` belongs to `user_identity`.
    ///
    /// Any earlier mapping for the same device key is replaced and discarded.
    pub fn insert(&self, device_key: EndpointId, user_identity: EndpointId) {
        let _ = self.inner.insert(device_key, user_identity);
    }

    /// Maps a transport key to its registered user identity.
    ///
    /// A key with no registered mapping resolves to itself.
    pub fn resolve(&self, transport_key: &EndpointId) -> EndpointId {
        self.inner
            .get(transport_key)
            .map_or(*transport_key, |user| *user.value())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built table has no route for either address family.
    #[test]
    fn test_peer_table_empty_lookup() {
        let table = PeerTable::new();
        let unknown_v4 = Ipv4Addr::new(100, 64, 0, 5);
        let unknown_v6 = Ipv6Addr::new(0x0200, 0, 0, 0, 0, 0, 0, 1);
        assert!(table.lookup_v4(&unknown_v4).is_none());
        assert!(table.lookup_v6(&unknown_v6).is_none());
    }

    /// A freshly built table reports no peers.
    #[test]
    fn test_peer_table_empty_ids() {
        let peers = PeerTable::new().all_peer_ids();
        assert!(peers.is_empty());
    }
}