use crate::{discv5::PERMIT_BAN_LIST, metrics::METRICS, node_info::NodeAddress, packet::Packet};
use cache::ReceivedPacketCache;
use enr::NodeId;
use hashlink::LruCache;
use std::{
collections::HashSet,
net::{IpAddr, SocketAddr},
sync::atomic::Ordering,
time::{Duration, Instant},
};
use tracing::{debug, warn};
mod cache;
mod config;
pub mod rate_limiter;
pub use config::FilterConfig;
use rate_limiter::{LimitKind, RateLimiter};
/// Capacity of the LRU cache mapping IP addresses to the node ids seen from them.
const KNOWN_ADDRS_SIZE: usize = 500;
/// Capacity of the LRU cache holding per-IP counters of banned node ids.
const BANNED_NODES_SIZE: usize = 50;
/// Expected packets per second used to size the received-packet (metrics) cache when no
/// rate limiter is configured.
const DEFAULT_PACKETS_PER_SECOND: usize = 20;
/// Decides whether incoming unsolicited packets are processed or dropped.
///
/// Checks are made against the crate-wide `PERMIT_BAN_LIST` and, when enabled, an optional
/// rate limiter plus per-IP limits on distinct node ids and on banned node ids.
pub(crate) struct Filter {
    /// Whether the rate-limit and per-IP checks are applied. The permit/ban list checks are
    /// performed regardless of this flag.
    enabled: bool,
    /// Optional rate limiter enforcing per-IP, per-node-id and total request limits.
    rate_limiter: Option<RateLimiter>,
    /// Records each packet that is neither from a permitted nor a banned IP; its length is
    /// published as the `unsolicited_requests_per_window` metric. Sized from the expected
    /// packet rate and `METRICS.moving_window`.
    raw_packets_received: ReceivedPacketCache,
    /// Length of every ban this filter issues; `None` inserts bans without a timeout.
    ban_duration: Option<Duration>,
    /// LRU map from IP address to the set of node ids seen from it
    /// (capacity `KNOWN_ADDRS_SIZE`).
    known_addrs: LruCache<IpAddr, HashSet<NodeId>>,
    /// LRU map from IP address to a counter updated each time a node id seen at that IP is
    /// banned (capacity `BANNED_NODES_SIZE`).
    banned_nodes: LruCache<IpAddr, usize>,
    /// When set, an IP is banned once the number of distinct node ids seen from it reaches
    /// this value. `None` disables the check.
    pub max_nodes_per_ip: Option<usize>,
    /// When set, an IP is banned once its node-ban counter reaches this value. `None`
    /// disables IP bans derived from node bans.
    pub max_bans_per_ip: Option<usize>,
}
impl Filter {
    /// Creates a new [`Filter`] from `config`.
    ///
    /// `ban_duration` is the length of every ban this filter issues; `None` inserts bans with
    /// no timeout. The received-packet cache used for metrics is sized from the rate
    /// limiter's total requests per second (rounded) when a rate limiter is configured, and
    /// from `DEFAULT_PACKETS_PER_SECOND` otherwise.
    pub fn new(config: FilterConfig, ban_duration: Option<Duration>) -> Filter {
        // `as usize` on a float saturates (negative/NaN become 0), so this cannot panic.
        let expected_packets_per_second = config
            .rate_limiter
            .as_ref()
            .map(|v| v.total_requests_per_second().round() as usize)
            .unwrap_or(DEFAULT_PACKETS_PER_SECOND);

        Filter {
            enabled: config.enabled,
            rate_limiter: config.rate_limiter,
            raw_packets_received: ReceivedPacketCache::new(
                expected_packets_per_second,
                METRICS.moving_window,
            ),
            known_addrs: LruCache::new(KNOWN_ADDRS_SIZE),
            banned_nodes: LruCache::new(BANNED_NODES_SIZE),
            ban_duration,
            max_nodes_per_ip: config.max_nodes_per_ip,
            max_bans_per_ip: config.max_bans_per_ip,
        }
    }

    /// First-stage check of an incoming packet, based only on its source address.
    ///
    /// Returns `true` if the packet may proceed and `false` if it must be dropped:
    /// - IPs on the permit list always pass (and are not counted in the metric below);
    /// - IPs on the ban list are dropped;
    /// - every other packet is recorded in the `unsolicited_requests_per_window` metric and,
    ///   when the filter is enabled and a rate limiter is configured, checked against the
    ///   per-IP limit (exceeding it bans the IP for `ban_duration`) and the total limit.
    pub fn initial_pass(&mut self, src: &SocketAddr) -> bool {
        let ip = src.ip();

        {
            // One read guard for both lookups. Scoped so it is released before the `write()`
            // below; taking a write lock while holding a read lock would self-deadlock.
            let list = PERMIT_BAN_LIST.read();
            if list.permit_ips.contains(&ip) {
                return true;
            }
            if list.ban_ips.contains_key(&ip) {
                debug!(?src, "Dropped unsolicited packet from banned src");
                return false;
            }
        }

        self.raw_packets_received.cache_insert();
        METRICS
            .unsolicited_requests_per_window
            .store(self.raw_packets_received.len(), Ordering::Relaxed);

        if !self.enabled {
            return true;
        }

        if let Some(rate_limiter) = self.rate_limiter.as_mut() {
            if rate_limiter.allows(&LimitKind::Ip(ip)).is_err() {
                warn!(ip = ?ip, "Banning IP for excessive requests");
                let ban_timeout = self.ban_duration.map(|v| Instant::now() + v);
                PERMIT_BAN_LIST.write().ban_ips.insert(ip, ban_timeout);
                return false;
            }
            if rate_limiter.allows(&LimitKind::Total).is_err() {
                debug!(ip = ?ip, "Dropped unsolicited packet from RPC limit");
                return false;
            }
        }

        true
    }

    /// Second-stage check of an incoming packet, based on the sender's node id and address.
    ///
    /// Returns `true` if the packet may proceed and `false` if it must be dropped:
    /// - node ids on the permit list always pass;
    /// - node ids on the ban list are dropped;
    /// - when the filter is enabled, a node exceeding its rate limit is banned. If
    ///   `max_bans_per_ip` is set, the ban is also counted against the node's IP, and the IP
    ///   is banned once that count reaches the limit;
    /// - when `max_nodes_per_ip` is set, the node id is recorded against its IP and the IP
    ///   is banned once the number of distinct node ids reaches the limit.
    ///
    /// `_packet` is currently unused.
    pub fn final_pass(&mut self, node_address: &NodeAddress, _packet: &Packet) -> bool {
        let node_id = node_address.node_id;

        {
            // One read guard for both lookups, released before any `write()` below.
            let list = PERMIT_BAN_LIST.read();
            if list.permit_nodes.contains(&node_id) {
                return true;
            }
            if list.ban_nodes.contains_key(&node_id) {
                debug!(
                    node = %node_address,
                    "Dropped unsolicited packet from banned node_id",
                );
                return false;
            }
        }

        if !self.enabled {
            return true;
        }

        if let Some(rate_limiter) = self.rate_limiter.as_mut() {
            if rate_limiter.allows(&LimitKind::NodeId(node_id)).is_err() {
                warn!(
                    node_id = %node_id,
                    "Node has exceeded its request limit and is now banned",
                );
                let ban_timeout = self.ban_duration.map(|v| Instant::now() + v);
                PERMIT_BAN_LIST
                    .write()
                    .ban_nodes
                    .insert(node_id, ban_timeout);

                if let Some(max_bans_per_ip) = self.max_bans_per_ip {
                    let ip = node_address.socket_addr.ip();
                    // Count actual node bans (the first ban is 1, mirroring how `known_addrs`
                    // counts node ids below) and check the threshold on every ban, including
                    // the first. Previously the first ban was stored as 0 and never checked,
                    // so IP bans lagged one behind the configured limit.
                    let banned_count = match self.banned_nodes.get_mut(&ip) {
                        Some(count) => {
                            *count += 1;
                            *count
                        }
                        None => {
                            self.banned_nodes.insert(ip, 1);
                            1
                        }
                    };
                    if banned_count >= max_bans_per_ip {
                        PERMIT_BAN_LIST.write().ban_ips.insert(ip, ban_timeout);
                    }
                }
                return false;
            }
        }

        if let Some(max_nodes_per_ip) = self.max_nodes_per_ip {
            let ip = node_address.socket_addr.ip();
            // Record this node id against its IP and get the number of distinct ids seen.
            let known_nodes = match self.known_addrs.get_mut(&ip) {
                Some(known_nodes) => {
                    known_nodes.insert(node_id);
                    known_nodes.len()
                }
                None => {
                    let mut ids = HashSet::new();
                    ids.insert(node_id);
                    self.known_addrs.insert(ip, ids);
                    1
                }
            };
            if known_nodes >= max_nodes_per_ip {
                warn!(%ip, "IP has exceeded its node-id limit and is now banned");
                let ban_timeout = self.ban_duration.map(|v| Instant::now() + v);
                PERMIT_BAN_LIST.write().ban_ips.insert(ip, ban_timeout);
                // The IP is now banned outright, so its node-id set is no longer needed.
                self.known_addrs.remove(&ip);
                return false;
            }
        }

        true
    }

    /// Calls `prune` on the rate limiter, if one is configured (presumably discarding
    /// stale limiter state — see `RateLimiter::prune`).
    pub fn prune_limiter(&mut self) {
        if let Some(rate_limiter) = self.rate_limiter.as_mut() {
            rate_limiter.prune();
        }
    }
}