use crate::network::protocol::NetworkAddress;
use crate::utils::current_timestamp;
use std::collections::HashMap;
use std::net::{IpAddr, SocketAddr};
#[cfg(feature = "iroh")]
use iroh::PublicKey;
/// A single known peer address together with its sighting metadata.
///
/// Note that `addr.services` is whatever the address carried when the entry
/// was created; only the separate `services` field accumulates later sightings.
#[derive(Debug, Clone)]
pub struct AddressEntry {
    /// The stored address (an all-zero placeholder for iroh peers).
    pub addr: NetworkAddress,
    /// `current_timestamp()` reading taken when the entry was created.
    pub first_seen: u64,
    /// `current_timestamp()` reading of the most recent sighting; used for
    /// freshness checks and to pick the eviction victim.
    pub last_seen: u64,
    /// Service flags; `AddressDatabase` ORs in the flags from every sighting.
    pub services: u64,
    /// Number of times this address has been seen (starts at 1).
    pub seen_count: u32,
}
impl AddressEntry {
    /// Creates an entry for `addr` that is being observed for the first time.
    ///
    /// `first_seen` and `last_seen` are set from the same `current_timestamp()`
    /// reading, and `seen_count` starts at 1.
    pub fn new(addr: NetworkAddress, services: u64) -> Self {
        let now = current_timestamp();
        Self {
            addr,
            first_seen: now,
            last_seen: now,
            services,
            seen_count: 1,
        }
    }

    /// Records another sighting: refreshes `last_seen` and bumps `seen_count`.
    ///
    /// The counter saturates at `u32::MAX`. A plain `+= 1` would panic in debug
    /// builds and silently wrap to 0 in release builds.
    pub fn update_seen(&mut self) {
        self.last_seen = current_timestamp();
        self.seen_count = self.seen_count.saturating_add(1);
    }

    /// Returns `true` if fewer than `expiration_seconds` have elapsed since
    /// `last_seen`. This assumes `current_timestamp()` is in seconds.
    ///
    /// The subtraction saturates, so a `last_seen` ahead of the current clock
    /// counts as fresh instead of underflowing. With `expiration_seconds == 0`
    /// no entry is ever fresh.
    pub fn is_fresh(&self, expiration_seconds: u64) -> bool {
        let now = current_timestamp();
        now.saturating_sub(self.last_seen) < expiration_seconds
    }
}
/// In-memory table of known peer addresses with a capacity limit and an
/// expiration window.
///
/// Socket addresses and (with the `iroh` feature) iroh public keys share one
/// capacity limit. When a new entry arrives at the limit, the entry with the
/// oldest `last_seen` across both maps is evicted.
pub struct AddressDatabase {
    /// Known socket addresses, keyed by the `SocketAddr` decoded from each
    /// `NetworkAddress`.
    addresses: HashMap<SocketAddr, AddressEntry>,
    /// Known iroh peers, keyed by public key. Their `AddressEntry::addr` is an
    /// all-zero placeholder because they have no socket address.
    #[cfg(feature = "iroh")]
    iroh_addresses: HashMap<PublicKey, AddressEntry>,
    /// Limit on the combined number of socket and iroh entries, checked before
    /// each new insertion.
    max_addresses: usize,
    /// Freshness window passed to `AddressEntry::is_fresh`.
    expiration_seconds: u64,
}
impl AddressDatabase {
    /// Default freshness window used by [`new`](Self::new): 24 hours, in seconds.
    const DEFAULT_EXPIRATION_SECONDS: u64 = 24 * 60 * 60;

    /// Upper bound on how many entries are pre-allocated at construction.
    ///
    /// The maps still grow on demand up to `max_addresses`. This bound only
    /// prevents a very large limit (e.g. `usize::MAX` for "effectively
    /// unbounded") from causing a huge up-front allocation or a
    /// capacity-overflow panic.
    const MAX_PREALLOCATED_ENTRIES: usize = 1 << 16;

    /// Creates an empty database that holds at most `max_addresses` entries,
    /// with a 24-hour expiration window.
    pub fn new(max_addresses: usize) -> Self {
        Self::with_expiration(max_addresses, Self::DEFAULT_EXPIRATION_SECONDS)
    }

    /// Creates an empty database that holds at most `max_addresses` entries.
    ///
    /// An entry is stale once `expiration_seconds` have passed since it was last
    /// seen. A `max_addresses` of 0 yields a database that stores nothing.
    pub fn with_expiration(max_addresses: usize, expiration_seconds: u64) -> Self {
        // `HashMap::with_capacity(n)` already guarantees room for `n` entries
        // without reallocating, so no manual load-factor padding is needed.
        // Padding by hand (`max * 4 / 3`, then rounding up to a power of two)
        // overflows for large limits.
        let capacity = max_addresses.min(Self::MAX_PREALLOCATED_ENTRIES);
        Self {
            addresses: HashMap::with_capacity(capacity),
            #[cfg(feature = "iroh")]
            iroh_addresses: HashMap::with_capacity(capacity / 2),
            max_addresses,
            expiration_seconds,
        }
    }

    /// Records a sighting of `addr` advertising `services`.
    ///
    /// If the address is already known, its `last_seen` and `seen_count` are
    /// refreshed and `services` is ORed into its flags. Otherwise it is
    /// inserted. If the combined socket + iroh count is already at
    /// `max_addresses`, the entry with the oldest `last_seen` is evicted first.
    pub fn add_address(&mut self, addr: NetworkAddress, services: u64) {
        // With a zero limit there is never anything to evict, so the insert
        // below would exceed the limit. Store nothing instead.
        if self.max_addresses == 0 {
            return;
        }
        let socket_addr = self.network_addr_to_socket(&addr);
        if let Some(entry) = self.addresses.get_mut(&socket_addr) {
            entry.update_seen();
            entry.services |= services;
            return;
        }
        if self.total_count() >= self.max_addresses {
            self.evict_oldest_unified();
        }
        self.addresses
            .insert(socket_addr, AddressEntry::new(addr, services));
    }

    /// Calls [`add_address`](Self::add_address) for each address, all with the
    /// same `services`.
    pub fn add_addresses(&mut self, addresses: Vec<NetworkAddress>, services: u64) {
        for addr in addresses {
            self.add_address(addr, services);
        }
    }

    /// Returns up to `count` fresh socket addresses, most recently seen first.
    pub fn get_fresh_addresses(&self, count: usize) -> Vec<NetworkAddress> {
        // Sort by reference so that only the returned addresses get cloned.
        let mut fresh: Vec<(u64, &NetworkAddress)> = self
            .addresses
            .values()
            .filter(|entry| entry.is_fresh(self.expiration_seconds))
            .map(|entry| (entry.last_seen, &entry.addr))
            .collect();
        fresh.sort_by(|a, b| b.0.cmp(&a.0));
        fresh
            .into_iter()
            .take(count)
            .map(|(_, addr)| addr.clone())
            .collect()
    }

    /// Returns every fresh socket address, most recently seen first.
    pub fn get_all_fresh_addresses(&self) -> Vec<NetworkAddress> {
        self.get_fresh_addresses(self.max_addresses)
    }

    /// Drops every stale entry and returns how many were removed.
    ///
    /// Iroh entries (when the feature is enabled) are purged as well, because
    /// they count toward the same capacity limit as socket addresses.
    pub fn remove_expired(&mut self) -> usize {
        let expiration = self.expiration_seconds;
        let before = self.total_count();
        self.addresses.retain(|_, entry| entry.is_fresh(expiration));
        #[cfg(feature = "iroh")]
        {
            self.iroh_addresses
                .retain(|_, entry| entry.is_fresh(expiration));
        }
        before - self.total_count()
    }

    /// Removes `addr` from the socket-address table, if present.
    pub fn remove_address(&mut self, addr: &NetworkAddress) {
        let socket_addr = self.network_addr_to_socket(addr);
        self.addresses.remove(&socket_addr);
    }

    /// Returns `true` if `addr` has an active ban in `ban_list`.
    ///
    /// `ban_list` maps a socket address to the timestamp at which its ban ends.
    /// `u64::MAX` marks a permanent ban.
    pub fn is_banned(&self, addr: &NetworkAddress, ban_list: &HashMap<SocketAddr, u64>) -> bool {
        let socket_addr = self.network_addr_to_socket(addr);
        match ban_list.get(&socket_addr) {
            Some(&unban_timestamp) => {
                unban_timestamp == u64::MAX || current_timestamp() < unban_timestamp
            }
            None => false,
        }
    }

    /// Returns `true` for addresses that are not useful as public peers.
    ///
    /// - IPv4: loopback, unspecified, private, link-local, broadcast and
    ///   multicast.
    /// - IPv6: loopback, unspecified, link-local unicast (`fe80::/10`), unique
    ///   local (`fc00::/7`) and multicast.
    ///
    /// IPv4-mapped IPv6 addresses are checked as IPv4.
    pub fn is_local(&self, addr: &NetworkAddress) -> bool {
        match self.network_addr_to_socket(addr).ip() {
            IpAddr::V4(ipv4) => {
                ipv4.is_loopback()
                    || ipv4.is_unspecified()
                    || ipv4.is_private()
                    || ipv4.is_link_local()
                    || ipv4.is_broadcast()
                    || ipv4.is_multicast()
            }
            IpAddr::V6(ipv6) => {
                let octets = ipv6.octets();
                ipv6.is_loopback()
                    || ipv6.is_unspecified()
                    // fe80::/10 link-local unicast
                    || (octets[0] == 0xfe && (octets[1] & 0xc0) == 0x80)
                    // fc00::/7 unique local (first byte 0xfc or 0xfd)
                    || (octets[0] & 0xfe) == 0xfc
                    || ipv6.is_multicast()
            }
        }
    }

    /// Keeps only the addresses worth connecting to.
    ///
    /// An address is dropped if it is local or non-routable (see
    /// [`is_local`](Self::is_local)), currently banned, or already in
    /// `connected_peers`.
    pub fn filter_addresses(
        &self,
        addresses: Vec<NetworkAddress>,
        ban_list: &HashMap<SocketAddr, u64>,
        connected_peers: &[SocketAddr],
    ) -> Vec<NetworkAddress> {
        addresses
            .into_iter()
            .filter(|addr| {
                let socket = self.network_addr_to_socket(addr);
                !self.is_local(addr)
                    && !self.is_banned(addr, ban_list)
                    && !connected_peers.contains(&socket)
            })
            .collect()
    }

    /// Records a sighting of the iroh peer `public_key` advertising `services`.
    ///
    /// Mirrors [`add_address`](Self::add_address). Known keys are refreshed and
    /// have `services` ORed in. New keys share the capacity limit and eviction
    /// policy with socket addresses. The stored `AddressEntry::addr` is an
    /// all-zero placeholder because an iroh peer has no socket address.
    #[cfg(feature = "iroh")]
    pub fn add_iroh_address(&mut self, public_key: PublicKey, services: u64) {
        // Same zero-limit guard as `add_address`.
        if self.max_addresses == 0 {
            return;
        }
        if let Some(entry) = self.iroh_addresses.get_mut(&public_key) {
            entry.update_seen();
            entry.services |= services;
            return;
        }
        if self.total_count() >= self.max_addresses {
            self.evict_oldest_unified();
        }
        let placeholder_addr = NetworkAddress {
            services,
            ip: [0; 16],
            port: 0,
        };
        self.iroh_addresses
            .insert(public_key, AddressEntry::new(placeholder_addr, services));
    }

    /// Returns up to `count` fresh iroh public keys, most recently seen first.
    #[cfg(feature = "iroh")]
    pub fn get_fresh_iroh_addresses(&self, count: usize) -> Vec<PublicKey> {
        let mut fresh: Vec<_> = self
            .iroh_addresses
            .iter()
            .filter(|(_, entry)| entry.is_fresh(self.expiration_seconds))
            .map(|(public_key, entry)| (entry.last_seen, *public_key))
            .collect();
        fresh.sort_by(|a, b| b.0.cmp(&a.0));
        fresh
            .into_iter()
            .map(|(_, public_key)| public_key)
            .take(count)
            .collect()
    }

    /// Combined number of socket and (with the `iroh` feature) iroh entries.
    ///
    /// This is the count compared against `max_addresses`.
    pub fn total_count(&self) -> usize {
        let socket_count = self.addresses.len();
        #[cfg(feature = "iroh")]
        {
            socket_count + self.iroh_addresses.len()
        }
        #[cfg(not(feature = "iroh"))]
        {
            socket_count
        }
    }

    /// Number of socket-address entries only.
    ///
    /// Use [`total_count`](Self::total_count) for the combined count.
    /// NOTE(review): `is_empty` checks both maps, so with the `iroh` feature
    /// `len() == 0` does not imply `is_empty()`.
    pub fn len(&self) -> usize {
        self.addresses.len()
    }

    /// Returns `true` when there are no socket entries and (with the `iroh`
    /// feature) no iroh entries.
    pub fn is_empty(&self) -> bool {
        #[cfg(feature = "iroh")]
        {
            self.addresses.is_empty() && self.iroh_addresses.is_empty()
        }
        #[cfg(not(feature = "iroh"))]
        {
            self.addresses.is_empty()
        }
    }

    /// Evicts the single entry with the oldest `last_seen` across both maps.
    ///
    /// When a socket entry and an iroh entry tie, the socket entry is evicted.
    /// Does nothing if both maps are empty. This is an O(n) scan.
    fn evict_oldest_unified(&mut self) {
        let oldest_socket: Option<(SocketAddr, u64)> = self
            .addresses
            .iter()
            .min_by_key(|(_, entry)| entry.last_seen)
            .map(|(addr, entry)| (*addr, entry.last_seen));
        #[cfg(feature = "iroh")]
        {
            let oldest_iroh: Option<(PublicKey, u64)> = self
                .iroh_addresses
                .iter()
                .min_by_key(|(_, entry)| entry.last_seen)
                .map(|(public_key, entry)| (*public_key, entry.last_seen));
            match (oldest_socket, oldest_iroh) {
                (Some((socket_addr, socket_time)), Some((iroh_id, iroh_time))) => {
                    if socket_time <= iroh_time {
                        self.addresses.remove(&socket_addr);
                    } else {
                        self.iroh_addresses.remove(&iroh_id);
                    }
                }
                (Some((socket_addr, _)), None) => {
                    self.addresses.remove(&socket_addr);
                }
                (None, Some((iroh_id, _))) => {
                    self.iroh_addresses.remove(&iroh_id);
                }
                (None, None) => {}
            }
        }
        #[cfg(not(feature = "iroh"))]
        {
            if let Some((socket_addr, _)) = oldest_socket {
                self.addresses.remove(&socket_addr);
            }
        }
    }

    /// Legacy alias for [`evict_oldest_unified`](Self::evict_oldest_unified).
    #[deprecated(note = "Use evict_oldest_unified() instead")]
    fn evict_oldest(&mut self) {
        self.evict_oldest_unified();
    }

    /// Legacy alias for [`evict_oldest_unified`](Self::evict_oldest_unified).
    #[cfg(feature = "iroh")]
    #[deprecated(note = "Use evict_oldest_unified() instead")]
    fn evict_oldest_iroh(&mut self) {
        self.evict_oldest_unified();
    }

    /// Decodes the 16-byte address in `addr` into a `SocketAddr`.
    ///
    /// An IPv4-mapped IPv6 address (`::ffff:a.b.c.d`) becomes an IPv4 socket
    /// address. Every other byte pattern is treated as native IPv6.
    pub fn network_addr_to_socket(&self, addr: &NetworkAddress) -> SocketAddr {
        const IPV4_MAPPED_PREFIX: [u8; 12] = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff];
        let ip = if addr.ip[..12] == IPV4_MAPPED_PREFIX {
            IpAddr::V4(std::net::Ipv4Addr::new(
                addr.ip[12],
                addr.ip[13],
                addr.ip[14],
                addr.ip[15],
            ))
        } else {
            IpAddr::V6(std::net::Ipv6Addr::from(addr.ip))
        };
        SocketAddr::new(ip, addr.port)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::IpAddr;

    /// Builds a `NetworkAddress` from a textual IP.
    ///
    /// IPv4 addresses are stored in IPv4-mapped IPv6 form (`::ffff:a.b.c.d`).
    fn create_test_address(ip: &str, port: u16) -> NetworkAddress {
        let socket = SocketAddr::new(ip.parse().unwrap(), port);
        let ip_bytes = match socket.ip() {
            IpAddr::V4(ipv4) => {
                let mut bytes = [0u8; 16];
                bytes[10] = 0xff;
                bytes[11] = 0xff;
                bytes[12..16].copy_from_slice(&ipv4.octets());
                bytes
            }
            IpAddr::V6(ipv6) => ipv6.octets(),
        };
        NetworkAddress {
            services: 0,
            ip: ip_bytes,
            port,
        }
    }

    #[test]
    fn test_address_database_creation() {
        let db = AddressDatabase::new(100);
        assert_eq!(db.len(), 0);
        assert!(db.is_empty());
    }

    #[test]
    fn test_add_address() {
        let mut db = AddressDatabase::new(100);
        let addr = create_test_address("192.168.1.1", 8333);
        db.add_address(addr.clone(), 1);
        assert_eq!(db.len(), 1);
        assert!(!db.is_empty());
    }

    #[test]
    fn test_add_duplicate_address() {
        let mut db = AddressDatabase::new(100);
        let addr = create_test_address("192.168.1.1", 8333);
        db.add_address(addr.clone(), 1);
        db.add_address(addr.clone(), 2);
        // The second sighting updates the existing entry instead of adding one.
        assert_eq!(db.len(), 1);
    }

    #[test]
    fn test_get_fresh_addresses() {
        let mut db = AddressDatabase::new(100);
        let addr1 = create_test_address("192.168.1.1", 8333);
        let addr2 = create_test_address("192.168.1.2", 8333);
        db.add_address(addr1.clone(), 1);
        db.add_address(addr2.clone(), 1);
        let fresh = db.get_fresh_addresses(10);
        assert_eq!(fresh.len(), 2);
    }

    #[test]
    fn test_address_expiration() {
        // A 0-second window makes every entry stale immediately, so expiry is
        // exercised without sleeping.
        let mut db = AddressDatabase::with_expiration(100, 0);
        let addr = create_test_address("192.168.1.1", 8333);
        db.add_address(addr.clone(), 1);
        assert_eq!(db.len(), 1);
        let fresh = db.get_fresh_addresses(10);
        assert_eq!(fresh.len(), 0);

        // The same address is fresh under a generous window.
        let mut db = AddressDatabase::with_expiration(100, 3600);
        db.add_address(addr, 1);
        assert_eq!(db.get_fresh_addresses(10).len(), 1);
    }

    #[test]
    fn test_remove_expired() {
        // A 0-second window makes every entry stale immediately (no sleep needed).
        let mut db = AddressDatabase::with_expiration(100, 0);
        let addr1 = create_test_address("192.168.1.1", 8333);
        let addr2 = create_test_address("192.168.1.2", 8333);
        db.add_address(addr1.clone(), 1);
        db.add_address(addr2.clone(), 1);
        assert_eq!(db.len(), 2);
        let removed = db.remove_expired();
        assert_eq!(removed, 2);
        assert_eq!(db.len(), 0);

        // Nothing is removed while entries are still fresh.
        let mut db = AddressDatabase::with_expiration(100, 3600);
        db.add_address(addr1, 1);
        assert_eq!(db.remove_expired(), 0);
        assert_eq!(db.len(), 1);
    }

    #[test]
    fn test_is_local() {
        let db = AddressDatabase::new(100);
        let localhost = create_test_address("127.0.0.1", 8333);
        let private = create_test_address("192.168.1.1", 8333);
        let public = create_test_address("8.8.8.8", 8333);
        assert!(db.is_local(&localhost));
        assert!(db.is_local(&private));
        assert!(!db.is_local(&public));
    }

    #[test]
    fn test_is_local_ipv6() {
        let db = AddressDatabase::new(100);
        let ipv6_localhost = create_test_address("::1", 8333);
        assert!(db.is_local(&ipv6_localhost));
        let ipv6_unspecified = create_test_address("::", 8333);
        assert!(db.is_local(&ipv6_unspecified));
        let ipv6_unique_local = create_test_address("fc00::1", 8333);
        assert!(db.is_local(&ipv6_unique_local));
        let ipv6_link_local = create_test_address("fe80::1", 8333);
        assert!(db.is_local(&ipv6_link_local));
        let ipv6_multicast = create_test_address("ff02::1", 8333);
        assert!(db.is_local(&ipv6_multicast));
        let ipv6_public = create_test_address("2001:4860:4860::8888", 8333);
        assert!(!db.is_local(&ipv6_public));
    }

    #[test]
    fn test_is_banned() {
        let db = AddressDatabase::new(100);
        let addr = create_test_address("192.168.1.1", 8333);
        let socket = SocketAddr::new("192.168.1.1".parse().unwrap(), 8333);
        let mut ban_list = HashMap::new();
        assert!(!db.is_banned(&addr, &ban_list));

        // Permanent ban.
        ban_list.insert(socket, u64::MAX);
        assert!(db.is_banned(&addr, &ban_list));

        // Ban that ends in the future.
        ban_list.clear();
        let future_time = crate::utils::current_timestamp() + 3600;
        ban_list.insert(socket, future_time);
        assert!(db.is_banned(&addr, &ban_list));

        // Ban that already ended.
        ban_list.clear();
        let past_time = crate::utils::current_timestamp().saturating_sub(3600);
        ban_list.insert(socket, past_time);
        assert!(!db.is_banned(&addr, &ban_list));
    }

    #[test]
    fn test_filter_addresses() {
        let db = AddressDatabase::new(100);
        let local = create_test_address("127.0.0.1", 8333);
        let banned = create_test_address("192.168.1.1", 8333);
        let public = create_test_address("8.8.8.8", 8333);
        let socket_banned = SocketAddr::new("192.168.1.1".parse().unwrap(), 8333);
        let socket_connected = SocketAddr::new("8.8.8.8".parse().unwrap(), 8333);
        let mut ban_list = HashMap::new();
        ban_list.insert(socket_banned, u64::MAX);
        let connected_peers = vec![socket_connected];
        let addresses = vec![local, banned, public];
        // Local, banned and already-connected addresses are all dropped.
        let filtered = db.filter_addresses(addresses, &ban_list, &connected_peers);
        assert_eq!(filtered.len(), 0);
    }

    #[test]
    fn test_eviction_when_full() {
        let mut db = AddressDatabase::new(2);
        let addr1 = create_test_address("192.168.1.1", 8333);
        let addr2 = create_test_address("192.168.1.2", 8333);
        let addr3 = create_test_address("192.168.1.3", 8333);
        db.add_address(addr1.clone(), 1);
        db.add_address(addr2.clone(), 1);
        assert_eq!(db.len(), 2);
        // Inserting a third address evicts one, keeping the size at the limit.
        db.add_address(addr3.clone(), 1);
        assert_eq!(db.len(), 2);
    }

    #[cfg(feature = "iroh")]
    #[test]
    fn test_add_iroh_address() {
        use getrandom::getrandom;
        use iroh::SecretKey;
        let mut db = AddressDatabase::new(100);
        let mut bytes = [0u8; 32];
        getrandom(&mut bytes).unwrap();
        let secret_key = SecretKey::from_bytes(&bytes);
        let public_key = secret_key.public();
        db.add_iroh_address(public_key, 1);
        assert_eq!(db.total_count(), 1);
        db.add_iroh_address(public_key, 2);
        assert_eq!(db.total_count(), 1);
    }

    #[cfg(feature = "iroh")]
    #[test]
    fn test_get_fresh_iroh_addresses() {
        use getrandom::getrandom;
        use iroh::SecretKey;
        let mut db = AddressDatabase::new(100);
        let mut bytes1 = [0u8; 32];
        let mut bytes2 = [0u8; 32];
        getrandom(&mut bytes1).unwrap();
        getrandom(&mut bytes2).unwrap();
        let secret_key1 = SecretKey::from_bytes(&bytes1);
        let secret_key2 = SecretKey::from_bytes(&bytes2);
        let public_key1 = secret_key1.public();
        let public_key2 = secret_key2.public();
        db.add_iroh_address(public_key1, 1);
        db.add_iroh_address(public_key2, 1);
        let fresh = db.get_fresh_iroh_addresses(10);
        assert_eq!(fresh.len(), 2);
    }

    #[cfg(feature = "iroh")]
    #[test]
    fn test_total_count_includes_iroh() {
        use iroh::PublicKey;
        let mut db = AddressDatabase::new(100);
        let addr = create_test_address("192.168.1.1", 8333);
        db.add_address(addr, 1);
        assert_eq!(db.total_count(), 1);
        let key_bytes = [0u8; 32];
        let public_key = PublicKey::from_bytes(&key_bytes).unwrap();
        db.add_iroh_address(public_key, 1);
        assert_eq!(db.total_count(), 2);
    }
}