pub mod config;
pub mod state;
pub mod store;
pub use config::RateLimitConfig;
pub use state::{RateLimitDecision, RateLimitState, TokenBucket};
pub use store::RateLimitStore;
use ahash::AHashMap;
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
/// Counters describing the limiter's in-memory (L1) cache and its ban list.
///
/// `HybridRateLimiter::stats` returns a clone of the current values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HybridRateLimiterStats {
/// Number of entries in the L1 map, as last recorded by the limiter.
pub l1_size: usize,
/// Capacity the limiter was constructed with; reaching it triggers eviction on insert.
pub l1_capacity: usize,
/// Number of rate-limit checks that found the agent already in L1.
pub l1_hits: u64,
/// Number of rate-limit checks that did not find the agent in L1.
pub l1_misses: u64,
/// Number of entries removed from L1 by the least-recently-accessed eviction.
pub l1_evictions: u64,
/// Ids of the agents present in the in-memory ban map at the last ban/unban.
pub banned_agents: Vec<String>,
}
/// One cached agent in the L1 map: its rate-limit state plus a recency stamp.
#[derive(Debug, Clone)]
struct L1Entry {
/// Rate-limit state of the agent this entry is stored under.
state: RateLimitState,
/// Set when the entry is inserted and refreshed on every L1 hit; the
/// eviction scan removes the entry with the smallest value.
last_accessed: std::time::Instant,
}
/// Rate limiter that keeps per-agent state in an in-memory map ("L1") and
/// falls back to a [`RateLimitStore`] for agents that are not cached.
///
/// All mutable pieces sit behind their own `RwLock`, so every method takes
/// `&self`.
pub struct HybridRateLimiter {
    /// Backing store used to load states on an L1 miss and to persist
    /// evicted states and bans.
    store: RateLimitStore,
    /// Supplies `max_tokens` and `replenish_rate` for newly created states.
    config: RateLimitConfig,
    /// In-memory cache of per-agent state, keyed by agent id.
    l1: Arc<RwLock<AHashMap<String, L1Entry>>>,
    /// Entry count at which an insert first evicts the least recently
    /// accessed entry.
    l1_capacity: usize,
    /// Currently banned agents: agent id mapped to the ban reason.
    banned: Arc<RwLock<HashMap<String, String>>>,
    /// Running counters handed out by `stats()`.
    stats: Arc<RwLock<HybridRateLimiterStats>>,
}
impl HybridRateLimiter {
pub fn new(
graph: &sqlitegraph::SqliteGraph,
config: RateLimitConfig,
l1_capacity: usize,
) -> crate::error::Result<Self> {
let store = RateLimitStore::new();
let mut l1_map = AHashMap::new();
if let Ok(states) = store.load_all(graph) {
for state in states {
l1_map.insert(
state.agent_id.clone(),
L1Entry {
state,
last_accessed: std::time::Instant::now(),
},
);
}
}
let l1 = Arc::new(RwLock::new(l1_map));
let stats = HybridRateLimiterStats {
l1_size: l1.read().len(),
l1_capacity,
l1_hits: 0,
l1_misses: 0,
l1_evictions: 0,
banned_agents: Vec::new(),
};
Ok(Self {
store,
config,
l1,
l1_capacity,
banned: Arc::new(RwLock::new(HashMap::new())),
stats: Arc::new(RwLock::new(stats)),
})
}
pub fn check_rate_limit(
&self,
graph: &sqlitegraph::SqliteGraph,
agent_id: &str,
) -> RateLimitDecision {
if self.banned.read().contains_key(agent_id) {
return RateLimitDecision {
allowed: false,
retry_after: Some(Duration::from_secs(3600)), };
}
{
let mut l1 = self.l1.write();
if let Some(entry) = l1.get_mut(agent_id) {
entry.last_accessed = std::time::Instant::now();
self.stats.write().l1_hits += 1;
return entry.state.check(1);
}
}
self.stats.write().l1_misses += 1;
if let Ok(Some(state)) = self.store.load(graph, agent_id) {
let mut l1 = self.l1.write();
if l1.len() >= self.l1_capacity {
self.evict_lru(graph, &mut l1);
}
l1.insert(
agent_id.to_string(),
L1Entry {
state,
last_accessed: std::time::Instant::now(),
},
);
self.stats.write().l1_size = l1.len();
if let Some(entry) = l1.get_mut(agent_id) {
return entry.state.check(1);
}
}
let mut state =
RateLimitState::new(agent_id, self.config.max_tokens, self.config.replenish_rate);
let decision = state.check(1);
let mut l1 = self.l1.write();
if l1.len() >= self.l1_capacity {
self.evict_lru(graph, &mut l1);
}
l1.insert(
agent_id.to_string(),
L1Entry {
state,
last_accessed: std::time::Instant::now(),
},
);
self.stats.write().l1_size = l1.len();
decision
}
fn evict_lru(&self, graph: &sqlitegraph::SqliteGraph, l1: &mut AHashMap<String, L1Entry>) {
let lru_key = l1
.iter()
.min_by_key(|(_, v)| v.last_accessed)
.map(|(k, _)| k.clone());
if let Some(key) = lru_key {
if let Some(entry) = l1.remove(&key) {
let _ = self.store.persist(graph, &entry.state);
self.stats.write().l1_evictions += 1;
}
}
}
pub fn replenish_all(&self, elapsed: Duration) {
let mut l1 = self.l1.write();
for entry in l1.values_mut() {
entry.state.replenish(elapsed);
}
}
pub fn ban_agent(
&self,
graph: &sqlitegraph::SqliteGraph,
agent_id: &str,
reason: &str,
) -> crate::error::Result<()> {
self.banned
.write()
.insert(agent_id.to_string(), reason.to_string());
self.l1.write().remove(agent_id);
let _ = self.store.persist_ban(graph, agent_id, reason);
let mut stats = self.stats.write();
stats.banned_agents = self.banned.read().keys().cloned().collect();
stats.l1_size = self.l1.read().len();
Ok(())
}
pub fn unban_agent(
&self,
graph: &sqlitegraph::SqliteGraph,
agent_id: &str,
) -> crate::error::Result<()> {
self.banned.write().remove(agent_id);
let _ = self.store.remove_ban(graph, agent_id);
let mut stats = self.stats.write();
stats.banned_agents = self.banned.read().keys().cloned().collect();
Ok(())
}
pub fn stats(&self) -> HybridRateLimiterStats {
self.stats.read().clone()
}
pub fn store(&self) -> &RateLimitStore {
&self.store
}
#[doc(hidden)]
pub fn clear_l1_for_testing(&self) {
self.l1.write().clear();
self.stats.write().l1_size = 0;
}
}