use anyhow::{Result, anyhow};
use lru::LruCache;
use serde::{Deserialize, Serialize};
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use std::num::NonZeroUsize;
/// Capacity of each LRU counter cache held by `BootstrapIpLimiter`.
///
/// Despite the name, both the per-IP map and the per-subnet map are sized with this value.
const BOOTSTRAP_MAX_TRACKED_SUBNETS: usize = 50_000;

/// Per-address cap applied when `IPDiversityConfig::max_per_ip` is unset.
pub const IP_EXACT_LIMIT: usize = 2;

/// K value used by the test-only limiter constructors.
#[cfg(test)]
const DEFAULT_K_VALUE: usize = 20;
/// Normalizes an address so that every spelling of the same host maps to one key.
///
/// IPv4-mapped IPv6 addresses (`::ffff:a.b.c.d`) are collapsed to their plain IPv4 form.
/// Every other address, including IPv4-compatible `::a.b.c.d`, is returned untouched.
pub fn canonicalize_ip(ip: IpAddr) -> IpAddr {
    let IpAddr::V6(v6) = ip else {
        return ip;
    };
    match v6.to_ipv4_mapped() {
        Some(embedded_v4) => IpAddr::V4(embedded_v4),
        None => ip,
    }
}
/// Default per-subnet cap derived from the K value: a quarter of `k` (rounded down),
/// but never less than one.
pub const fn ip_subnet_limit(k: usize) -> usize {
    match k / 4 {
        0 => 1,
        quarter => quarter,
    }
}
/// Optional per-IP and per-subnet admission limits for bootstrap peers.
///
/// `None` means "use the built-in default": `BootstrapIpLimiter` falls back to
/// `IP_EXACT_LIMIT` for `max_per_ip` and to `ip_subnet_limit(k)` for `max_per_subnet`.
/// Unset fields are omitted when serialized and default to `None` when deserialized.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct IPDiversityConfig {
/// Maximum number of tracked entries for a single (canonicalized) IP address.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub max_per_ip: Option<usize>,
/// Maximum number of tracked entries per subnet (/24 for IPv4, /48 for IPv6).
#[serde(default, skip_serializing_if = "Option::is_none")]
pub max_per_subnet: Option<usize>,
}
impl IPDiversityConfig {
    /// Configuration intended for test networks.
    ///
    /// Currently identical to [`IPDiversityConfig::permissive`].
    #[must_use]
    pub fn testnet() -> Self {
        Self::permissive()
    }

    /// Configuration that effectively disables diversity limiting by setting both caps to
    /// `usize::MAX`.
    #[must_use]
    pub fn permissive() -> Self {
        let unlimited = Some(usize::MAX);
        Self {
            max_per_ip: unlimited,
            max_per_subnet: unlimited,
        }
    }

    /// Checks that every explicitly configured limit is at least one.
    ///
    /// Unset (`None`) limits are always valid.
    ///
    /// # Errors
    ///
    /// Returns an error naming the first offending field when its value is zero.
    /// `max_per_ip` is checked before `max_per_subnet`.
    pub fn validate(&self) -> Result<()> {
        if let Some(limit) = self.max_per_ip.filter(|&value| value < 1) {
            anyhow::bail!("max_per_ip must be >= 1 (got {limit})");
        }
        if let Some(limit) = self.max_per_subnet.filter(|&value| value < 1) {
            anyhow::bail!("max_per_subnet must be >= 1 (got {limit})");
        }
        Ok(())
    }
}
/// Admission limiter that caps how many bootstrap entries share one IP address or one
/// subnet.
///
/// Addresses are canonicalized (IPv4-mapped IPv6 becomes IPv4) before counting.
/// Counters live in bounded LRU caches, so a count can be forgotten once its entry is
/// evicted.
#[derive(Debug)]
pub struct BootstrapIpLimiter {
/// Explicit limits; unset fields fall back to built-in defaults.
config: IPDiversityConfig,
/// Whether loopback addresses are accepted. When accepted, they bypass all limits.
allow_loopback: bool,
/// K value used only to derive the default per-subnet limit via `ip_subnet_limit`.
k_value: usize,
/// Number of tracked entries per canonical IP address.
ip_counts: LruCache<IpAddr, usize>,
/// Number of tracked entries per subnet key (see `subnet_key`).
subnet_counts: LruCache<IpAddr, usize>,
}
impl BootstrapIpLimiter {
#[cfg(test)]
pub fn new(config: IPDiversityConfig) -> Self {
Self::with_loopback(config, false)
}
#[cfg(test)]
pub fn with_loopback(config: IPDiversityConfig, allow_loopback: bool) -> Self {
Self::with_loopback_and_k(config, allow_loopback, DEFAULT_K_VALUE)
}
pub fn with_loopback_and_k(
config: IPDiversityConfig,
allow_loopback: bool,
k_value: usize,
) -> Self {
let cache_size =
NonZeroUsize::new(BOOTSTRAP_MAX_TRACKED_SUBNETS).unwrap_or(NonZeroUsize::MIN);
Self {
config,
allow_loopback,
k_value,
ip_counts: LruCache::new(cache_size),
subnet_counts: LruCache::new(cache_size),
}
}
fn subnet_key(ip: IpAddr) -> IpAddr {
match ip {
IpAddr::V4(v4) => {
let o = v4.octets();
IpAddr::V4(Ipv4Addr::new(o[0], o[1], o[2], 0))
}
IpAddr::V6(v6) => {
let mut o = v6.octets();
for b in &mut o[6..] {
*b = 0;
}
IpAddr::V6(Ipv6Addr::from(o))
}
}
}
pub fn can_accept(&self, ip: IpAddr) -> bool {
let ip = canonicalize_ip(ip);
if ip.is_loopback() {
return self.allow_loopback;
}
if ip.is_unspecified() || ip.is_multicast() {
return false;
}
let ip_limit = self.config.max_per_ip.unwrap_or(IP_EXACT_LIMIT);
let subnet_limit = self
.config
.max_per_subnet
.unwrap_or(ip_subnet_limit(self.k_value));
if let Some(&count) = self.ip_counts.peek(&ip)
&& count >= ip_limit
{
return false;
}
let subnet = Self::subnet_key(ip);
if let Some(&count) = self.subnet_counts.peek(&subnet)
&& count >= subnet_limit
{
return false;
}
true
}
pub fn track(&mut self, ip: IpAddr) -> Result<()> {
let ip = canonicalize_ip(ip);
if !self.can_accept(ip) {
return Err(anyhow!("IP diversity limits exceeded"));
}
let count = self.ip_counts.get(&ip).copied().unwrap_or(0) + 1;
self.ip_counts.put(ip, count);
let subnet = Self::subnet_key(ip);
let count = self.subnet_counts.get(&subnet).copied().unwrap_or(0) + 1;
self.subnet_counts.put(subnet, count);
Ok(())
}
#[allow(dead_code)]
pub fn untrack(&mut self, ip: IpAddr) {
let ip = canonicalize_ip(ip);
if let Some(count) = self.ip_counts.peek_mut(&ip) {
*count = count.saturating_sub(1);
if *count == 0 {
self.ip_counts.pop(&ip);
}
}
let subnet = Self::subnet_key(ip);
if let Some(count) = self.subnet_counts.peek_mut(&subnet) {
*count = count.saturating_sub(1);
if *count == 0 {
self.subnet_counts.pop(&subnet);
}
}
}
}
#[cfg(test)]
impl BootstrapIpLimiter {
    /// Borrows the diversity configuration this limiter was constructed with.
    // NOTE(review): no test in this file calls this accessor (they read `.config`
    // directly), which is presumably why `dead_code` is allowed.
    #[allow(dead_code)]
    pub fn config(&self) -> &IPDiversityConfig {
        let Self { config, .. } = self;
        config
    }
}
/// Source of geolocation / network-ownership metadata for an address.
///
/// NOTE(review): `lookup` takes only `Ipv6Addr`. Callers holding an IPv4 address
/// presumably map it first; confirm the expected mapping with implementors.
// NOTE(review): no implementor or caller is visible in this file, hence `dead_code` is
// allowed; presumably reserved for future geo-aware diversity checks.
#[allow(dead_code)]
pub trait GeoProvider: std::fmt::Debug {
/// Returns the metadata known for `ip`.
fn lookup(&self, ip: Ipv6Addr) -> GeoInfo;
}
/// Metadata returned by [`GeoProvider::lookup`].
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct GeoInfo {
/// Autonomous system number, if known.
pub asn: Option<u32>,
/// Country identifier, if known. The format is not fixed here; presumably an ISO code (confirm).
pub country: Option<String>,
/// Whether the address is attributed to a hosting provider.
pub is_hosting_provider: bool,
/// Whether the address is attributed to a VPN provider.
pub is_vpn_provider: bool,
}
/// Unit tests for IP canonicalization, config validation and `BootstrapIpLimiter`.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ip_diversity_config_default() {
        let config = IPDiversityConfig::default();
        assert!(config.max_per_ip.is_none());
        assert!(config.max_per_subnet.is_none());
    }

    #[test]
    fn test_validate_rejects_zero_limits() {
        let zero_ip = IPDiversityConfig {
            max_per_ip: Some(0),
            max_per_subnet: None,
        };
        let err = zero_ip.validate().unwrap_err();
        assert!(err.to_string().contains("max_per_ip"));
        let zero_subnet = IPDiversityConfig {
            max_per_ip: None,
            max_per_subnet: Some(0),
        };
        let err = zero_subnet.validate().unwrap_err();
        assert!(err.to_string().contains("max_per_subnet"));
    }

    #[test]
    fn test_validate_accepts_presets() {
        assert!(IPDiversityConfig::default().validate().is_ok());
        assert!(IPDiversityConfig::permissive().validate().is_ok());
        assert!(IPDiversityConfig::testnet().validate().is_ok());
    }

    #[test]
    fn test_ip_subnet_limit_floor() {
        assert_eq!(ip_subnet_limit(0), 1);
        assert_eq!(ip_subnet_limit(3), 1);
        assert_eq!(ip_subnet_limit(4), 1);
        assert_eq!(ip_subnet_limit(20), 5);
    }

    #[test]
    fn test_bootstrap_ip_limiter_creation() {
        let config = IPDiversityConfig {
            max_per_ip: None,
            max_per_subnet: Some(1),
        };
        let enforcer = BootstrapIpLimiter::with_loopback(config.clone(), true);
        assert_eq!(enforcer.config.max_per_subnet, config.max_per_subnet);
    }

    #[test]
    fn test_can_accept_basic() {
        let config = IPDiversityConfig::default();
        let enforcer = BootstrapIpLimiter::new(config);
        let ip: IpAddr = "192.168.1.1".parse().unwrap();
        assert!(enforcer.can_accept(ip));
    }

    #[test]
    fn test_ip_limit_enforcement() {
        let config = IPDiversityConfig {
            max_per_ip: Some(1),
            max_per_subnet: Some(usize::MAX),
        };
        let mut enforcer = BootstrapIpLimiter::new(config);
        let ip: IpAddr = "10.0.0.1".parse().unwrap();
        assert!(enforcer.can_accept(ip));
        enforcer.track(ip).unwrap();
        assert!(!enforcer.can_accept(ip));
        assert!(enforcer.track(ip).is_err());
    }

    #[test]
    fn test_subnet_limit_enforcement_ipv4() {
        let config = IPDiversityConfig {
            max_per_ip: Some(usize::MAX),
            max_per_subnet: Some(2),
        };
        let mut enforcer = BootstrapIpLimiter::new(config);
        let ip1: IpAddr = "10.0.1.1".parse().unwrap();
        let ip2: IpAddr = "10.0.1.2".parse().unwrap();
        let ip3: IpAddr = "10.0.1.3".parse().unwrap();
        enforcer.track(ip1).unwrap();
        enforcer.track(ip2).unwrap();
        assert!(!enforcer.can_accept(ip3));
        assert!(enforcer.track(ip3).is_err());
        let ip_other: IpAddr = "10.0.2.1".parse().unwrap();
        assert!(enforcer.can_accept(ip_other));
    }

    #[test]
    fn test_subnet_limit_enforcement_ipv6() {
        let config = IPDiversityConfig {
            max_per_ip: Some(usize::MAX),
            max_per_subnet: Some(1),
        };
        let mut enforcer = BootstrapIpLimiter::new(config);
        let ip1: IpAddr = "2001:db8:85a3:1234::1".parse().unwrap();
        let ip2: IpAddr = "2001:db8:85a3:5678::2".parse().unwrap();
        enforcer.track(ip1).unwrap();
        assert!(!enforcer.can_accept(ip2));
        let ip_other: IpAddr = "2001:db8:aaaa::1".parse().unwrap();
        assert!(enforcer.can_accept(ip_other));
    }

    #[test]
    fn test_track_and_untrack() {
        let config = IPDiversityConfig {
            max_per_ip: Some(1),
            max_per_subnet: Some(usize::MAX),
        };
        let mut enforcer = BootstrapIpLimiter::new(config);
        let ip: IpAddr = "10.0.0.1".parse().unwrap();
        enforcer.track(ip).unwrap();
        assert!(!enforcer.can_accept(ip));
        enforcer.untrack(ip);
        assert!(enforcer.can_accept(ip));
        enforcer.track(ip).unwrap();
        assert!(!enforcer.can_accept(ip));
    }

    #[test]
    fn test_untrack_unknown_ip_is_noop() {
        let config = IPDiversityConfig {
            max_per_ip: Some(usize::MAX),
            max_per_subnet: Some(1),
        };
        let mut enforcer = BootstrapIpLimiter::new(config);
        let ip: IpAddr = "10.0.3.1".parse().unwrap();
        // Releasing an address that was never tracked must not drive any counter below
        // zero, so the subnet limit still bites after a single real track.
        enforcer.untrack(ip);
        enforcer.track(ip).unwrap();
        let neighbour: IpAddr = "10.0.3.2".parse().unwrap();
        assert!(!enforcer.can_accept(neighbour));
    }

    #[test]
    fn test_loopback_bypass() {
        let config = IPDiversityConfig {
            max_per_ip: Some(1),
            max_per_subnet: Some(1),
        };
        let enforcer = BootstrapIpLimiter::with_loopback(config.clone(), true);
        let loopback_v4: IpAddr = "127.0.0.1".parse().unwrap();
        let loopback_v6: IpAddr = "::1".parse().unwrap();
        assert!(enforcer.can_accept(loopback_v4));
        assert!(enforcer.can_accept(loopback_v6));
        let enforcer_no_lb = BootstrapIpLimiter::new(config);
        assert!(
            !enforcer_no_lb.can_accept(loopback_v4),
            "loopback should be rejected when allow_loopback=false"
        );
        assert!(
            !enforcer_no_lb.can_accept(loopback_v6),
            "loopback IPv6 should be rejected when allow_loopback=false"
        );
    }

    #[test]
    fn test_ipv4_mapped_loopback_follows_loopback_policy() {
        let config = IPDiversityConfig::default();
        let mapped: IpAddr = "::ffff:127.0.0.1".parse().unwrap();
        assert!(BootstrapIpLimiter::with_loopback(config.clone(), true).can_accept(mapped));
        assert!(
            !BootstrapIpLimiter::new(config).can_accept(mapped),
            "IPv4-mapped loopback should be canonicalized and rejected when allow_loopback=false"
        );
    }

    #[test]
    fn test_subnet_key_ipv4() {
        let ip: IpAddr = "192.168.42.100".parse().unwrap();
        let subnet = BootstrapIpLimiter::subnet_key(ip);
        let expected: IpAddr = "192.168.42.0".parse().unwrap();
        assert_eq!(subnet, expected);
    }

    #[test]
    fn test_subnet_key_ipv6() {
        let ip: IpAddr = "2001:db8:85a3:1234:5678:8a2e:0370:7334".parse().unwrap();
        let subnet = BootstrapIpLimiter::subnet_key(ip);
        let expected: IpAddr = "2001:db8:85a3::".parse().unwrap();
        assert_eq!(subnet, expected);
    }

    #[test]
    fn test_default_ip_limit_is_two() {
        let config = IPDiversityConfig::default();
        let mut enforcer = BootstrapIpLimiter::new(config);
        let ip1: IpAddr = "10.0.0.1".parse().unwrap();
        enforcer.track(ip1).unwrap();
        enforcer.track(ip1).unwrap();
        assert!(!enforcer.can_accept(ip1));
    }

    #[test]
    fn test_default_subnet_limit_matches_k() {
        let config = IPDiversityConfig::default();
        let mut enforcer = BootstrapIpLimiter::new(config);
        for i in 1..=5 {
            let ip: IpAddr = format!("10.0.1.{i}").parse().unwrap();
            enforcer.track(ip).unwrap();
        }
        let ip6: IpAddr = "10.0.1.6".parse().unwrap();
        assert!(
            !enforcer.can_accept(ip6),
            "6th peer in same /24 should exceed K/4=5 subnet limit"
        );
    }

    #[test]
    fn test_ipv4_mapped_ipv6_counts_as_ipv4() {
        let config = IPDiversityConfig {
            max_per_ip: Some(1),
            max_per_subnet: Some(usize::MAX),
        };
        let mut enforcer = BootstrapIpLimiter::new(config);
        let ipv4: IpAddr = "10.0.0.1".parse().unwrap();
        enforcer.track(ipv4).unwrap();
        let mapped: IpAddr = "::ffff:10.0.0.1".parse().unwrap();
        assert!(
            !enforcer.can_accept(mapped),
            "IPv4-mapped IPv6 should be canonicalized and hit the IPv4 limit"
        );
    }

    #[test]
    fn test_multicast_rejected() {
        let config = IPDiversityConfig::default();
        let enforcer = BootstrapIpLimiter::new(config);
        let multicast_v4: IpAddr = "224.0.0.1".parse().unwrap();
        assert!(!enforcer.can_accept(multicast_v4));
        let multicast_v6: IpAddr = "ff02::1".parse().unwrap();
        assert!(!enforcer.can_accept(multicast_v6));
    }

    #[test]
    fn test_unspecified_rejected() {
        let config = IPDiversityConfig::default();
        let enforcer = BootstrapIpLimiter::new(config);
        let unspec_v4: IpAddr = "0.0.0.0".parse().unwrap();
        assert!(!enforcer.can_accept(unspec_v4));
        let unspec_v6: IpAddr = "::".parse().unwrap();
        assert!(!enforcer.can_accept(unspec_v6));
    }

    #[test]
    fn test_untrack_ipv4_mapped_ipv6() {
        let config = IPDiversityConfig {
            max_per_ip: Some(1),
            max_per_subnet: Some(usize::MAX),
        };
        let mut enforcer = BootstrapIpLimiter::new(config);
        let ipv4: IpAddr = "10.0.0.1".parse().unwrap();
        enforcer.track(ipv4).unwrap();
        assert!(!enforcer.can_accept(ipv4));
        let mapped: IpAddr = "::ffff:10.0.0.1".parse().unwrap();
        enforcer.untrack(mapped);
        assert!(
            enforcer.can_accept(ipv4),
            "untrack via mapped form should decrement the IPv4 counter"
        );
    }
}