use std::net::IpAddr;
use std::num::NonZeroUsize;
use lru::LruCache;
use parking_lot::RwLock;
use tracing::debug;
use super::api::PowPenaltyReason;
use super::config::PowConfig;
use super::error::PowError;
use super::extractors::AddressKey;
use crate::prelude::*;
/// Penalty counter tracked for a single [`AddressKey`] (an individual address or a network).
#[derive(Debug, Clone)]
pub struct PowCounterEntry {
    /// Counter value as of `last_incremented`; readers subtract elapsed-time decay from it.
    pub counter: u32,
    /// When `counter` was last written (by an increment or a decrement); decay is measured
    /// from this instant.
    pub last_incremented: Timestamp,
    /// Reason supplied with the most recent increment (decrements leave it unchanged).
    pub reason: PowPenaltyReason,
}
/// Tracks proof-of-work penalty counters for individual addresses and for networks.
///
/// Both maps are bounded LRU caches behind `parking_lot` read/write locks, so every
/// operation works through `&self`.
pub struct PowCounterStore {
    /// Counters keyed by `AddressKey::from_ip_individual`.
    individual: RwLock<LruCache<AddressKey, PowCounterEntry>>,
    /// Counters keyed by `AddressKey::from_ip_network`; one entry can apply to many addresses.
    network: RwLock<LruCache<AddressKey, PowCounterEntry>>,
    /// Cache capacities, decay interval and counter cap.
    config: PowConfig,
}
impl PowCounterStore {
pub fn new(config: PowConfig) -> Self {
const FIFTY_THOUSAND: NonZeroUsize = match NonZeroUsize::new(50_000) {
Some(v) => v,
None => unreachable!(),
};
const TEN_THOUSAND: NonZeroUsize = match NonZeroUsize::new(10_000) {
Some(v) => v,
None => unreachable!(),
};
let individual_cap =
NonZeroUsize::new(config.max_individual_entries).unwrap_or(FIFTY_THOUSAND);
let network_cap = NonZeroUsize::new(config.max_network_entries).unwrap_or(TEN_THOUSAND);
Self {
individual: RwLock::new(LruCache::new(individual_cap)),
network: RwLock::new(LruCache::new(network_cap)),
config,
}
}
pub fn get_requirement(&self, addr: &IpAddr) -> u32 {
let individual_key = AddressKey::from_ip_individual(addr);
let network_key = AddressKey::from_ip_network(addr);
let individual_count = self.get_counter_value(&self.individual, &individual_key);
let network_count = self.get_counter_value(&self.network, &network_key);
individual_count.max(network_count)
}
pub fn verify(&self, addr: &IpAddr, token: &str) -> Result<(), PowError> {
let required = self.get_requirement(addr);
if required == 0 {
return Ok(());
}
let suffix = "A".repeat(required as usize);
if token.ends_with(&suffix) {
Ok(())
} else {
Err(PowError::InsufficientWork { required, suffix })
}
}
pub fn increment(&self, addr: &IpAddr, reason: PowPenaltyReason) {
let individual_key = AddressKey::from_ip_individual(addr);
self.increment_entry(&self.individual, individual_key.clone(), reason);
if reason.affects_network() {
let network_key = AddressKey::from_ip_network(addr);
self.increment_entry(&self.network, network_key, reason);
}
debug!(
"PoW counter incremented for {:?} (reason: {:?}), new requirement: {}",
individual_key,
reason,
self.get_requirement(addr)
);
}
pub fn decrement(&self, addr: &IpAddr, amount: u32) {
let individual_key = AddressKey::from_ip_individual(addr);
self.decrement_entry(&self.individual, &individual_key, amount);
let network_key = AddressKey::from_ip_network(addr);
self.decrement_entry(&self.network, &network_key, amount);
}
fn get_counter_value(
&self,
cache: &RwLock<LruCache<AddressKey, PowCounterEntry>>,
key: &AddressKey,
) -> u32 {
let cache = cache.read();
if let Some(entry) = cache.peek(key) { self.apply_decay(entry) } else { 0 }
}
fn apply_decay(&self, entry: &PowCounterEntry) -> u32 {
let now = Timestamp::now();
let elapsed_secs = u64::try_from((now.0 - entry.last_incremented.0).max(0)).unwrap_or(0);
let decay =
u32::try_from(elapsed_secs / self.config.decay_interval_secs).unwrap_or(u32::MAX);
entry.counter.saturating_sub(decay)
}
fn increment_entry(
&self,
cache: &RwLock<LruCache<AddressKey, PowCounterEntry>>,
key: AddressKey,
reason: PowPenaltyReason,
) {
let mut cache = cache.write();
let now = Timestamp::now();
if let Some(entry) = cache.get_mut(&key) {
let decayed = self.apply_decay(entry);
entry.counter = decayed.saturating_add(1).min(self.config.max_counter);
entry.last_incremented = now;
entry.reason = reason;
} else {
cache.put(key, PowCounterEntry { counter: 1, last_incremented: now, reason });
}
}
fn decrement_entry(
&self,
cache: &RwLock<LruCache<AddressKey, PowCounterEntry>>,
key: &AddressKey,
amount: u32,
) {
let mut cache = cache.write();
if let Some(entry) = cache.get_mut(key) {
let decayed = self.apply_decay(entry);
let new_value = decayed.saturating_sub(amount);
if new_value == 0 {
cache.pop(key);
} else {
entry.counter = new_value;
entry.last_incremented = Timestamp::now();
}
}
}
pub fn individual_count(&self) -> usize {
self.individual.read().len()
}
pub fn network_count(&self) -> usize {
self.network.read().len()
}
}
impl Default for PowCounterStore {
fn default() -> Self {
Self::new(PowConfig::default())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::Ipv4Addr;

    /// Returns `192.168.1.<host>`; all test addresses share their first three octets.
    fn lan(host: u8) -> IpAddr {
        IpAddr::V4(Ipv4Addr::new(192, 168, 1, host))
    }

    /// Calls `store.increment(addr, reason)` `times` times.
    fn penalize(store: &PowCounterStore, addr: &IpAddr, reason: PowPenaltyReason, times: usize) {
        for _ in 0..times {
            store.increment(addr, reason);
        }
    }

    #[test]
    fn test_pow_store_basic() {
        let store = PowCounterStore::default();
        let addr = lan(100);
        assert_eq!(store.get_requirement(&addr), 0);
        store.increment(&addr, PowPenaltyReason::ConnSignatureFailure);
        assert_eq!(store.get_requirement(&addr), 1);
        store.increment(&addr, PowPenaltyReason::ConnDuplicatePending);
        assert_eq!(store.get_requirement(&addr), 2);
    }

    #[test]
    fn test_pow_verification() {
        let store = PowCounterStore::default();
        let addr = lan(100);
        // Before any penalty, every token is accepted.
        assert!(store.verify(&addr, "some_token").is_ok());
        penalize(&store, &addr, PowPenaltyReason::ConnSignatureFailure, 3);
        let cases = [
            ("some_token", false),
            ("some_tokenAA", false),
            ("some_tokenAAA", true),
            ("some_tokenAAAA", true),
        ];
        for &(token, accepted) in &cases {
            assert_eq!(store.verify(&addr, token).is_ok(), accepted, "token {:?}", token);
        }
    }

    #[test]
    fn test_pow_network_level() {
        let store = PowCounterStore::default();
        let (offender, neighbour) = (lan(100), lan(200));
        store.increment(&offender, PowPenaltyReason::ConnSignatureFailure);
        // The neighbour shares the offender's network counter.
        assert!(store.get_requirement(&neighbour) >= 1);
    }

    #[test]
    fn test_pow_individual_only() {
        let store = PowCounterStore::default();
        let (offender, neighbour) = (lan(100), lan(200));
        store.increment(&offender, PowPenaltyReason::ConnRejected);
        assert_eq!(store.get_requirement(&offender), 1);
        assert_eq!(store.get_requirement(&neighbour), 0);
    }

    #[test]
    fn test_pow_max_counter() {
        let store = PowCounterStore::new(PowConfig { max_counter: 3, ..Default::default() });
        let addr = lan(100);
        penalize(&store, &addr, PowPenaltyReason::ConnSignatureFailure, 10);
        assert_eq!(store.get_requirement(&addr), 3);
    }

    #[test]
    fn test_pow_decrement() {
        let store = PowCounterStore::default();
        let addr = lan(100);
        penalize(&store, &addr, PowPenaltyReason::ConnSignatureFailure, 3);
        assert_eq!(store.get_requirement(&addr), 3);
        store.decrement(&addr, 1);
        assert_eq!(store.get_requirement(&addr), 2);
        // Decrementing by more than the counter clamps at zero.
        store.decrement(&addr, 10);
        assert_eq!(store.get_requirement(&addr), 0);
    }
}