use std::collections::HashMap;
use std::time::Instant;
use hirn_core::HirnResult;
use hirn_core::types::AgentId;
use tokio::sync::Mutex;
use crate::admission::{AdmissionController, AdmissionDecision, MemoryCandidate};
/// Per-agent write rate limiter over a sliding time window.
///
/// Each agent gets its own list of timestamps for writes that were accepted.
/// A new write is accepted only while fewer than `max_writes` of those
/// timestamps fall inside the last `window_secs` seconds.
pub struct RateLimiter {
/// Number of accepted writes an agent may have inside one window before
/// further writes are rejected.
max_writes: u64,
/// Length of the sliding window, in seconds.
window_secs: u64,
/// Accepted-write timestamps keyed by agent, behind a `tokio::sync::Mutex`
/// so the guard can be taken with `.lock().await` from async code.
state: Mutex<HashMap<AgentId, Vec<Instant>>>,
}
impl RateLimiter {
    /// Creates a limiter that accepts at most `max_writes` writes per agent
    /// within any `window_secs`-second window.
    ///
    /// The limiter starts with no recorded writes for any agent.
    pub fn new(max_writes: u64, window_secs: u64) -> Self {
        let state = Mutex::new(HashMap::new());
        Self {
            max_writes,
            window_secs,
            state,
        }
    }

    /// Creates a limiter with the built-in defaults: 100 writes per
    /// 60-second window.
    pub fn with_defaults() -> Self {
        const DEFAULT_MAX_WRITES: u64 = 100;
        const DEFAULT_WINDOW_SECS: u64 = 60;
        Self::new(DEFAULT_MAX_WRITES, DEFAULT_WINDOW_SECS)
    }

    /// Drops every timestamp that is no longer inside the window ending at
    /// `now`.
    ///
    /// A timestamp is kept only when its age is strictly less than `window`,
    /// so a zero-length window keeps nothing.
    fn prune(timestamps: &mut Vec<Instant>, now: Instant, window: std::time::Duration) {
        timestamps.retain(|&recorded| {
            let age = now.duration_since(recorded);
            age < window
        });
    }
}
#[async_trait::async_trait]
impl AdmissionController for RateLimiter {
    /// Identifier reported for this controller.
    fn name(&self) -> &str {
        "rate_limiter"
    }

    /// Decides whether `candidate` may be written, based on how many writes
    /// its agent already had accepted inside the current window.
    ///
    /// Expired timestamps for the agent are pruned first. If the remaining
    /// count is below `max_writes`, the write is accepted and its timestamp is
    /// recorded; otherwise it is rejected with a descriptive reason and
    /// nothing is recorded. This implementation always returns `Ok`.
    ///
    /// NOTE(review): an agent's map entry is created on first use and is not
    /// removed anywhere in this block, so the map grows with the number of
    /// distinct agents seen.
    async fn evaluate(&self, candidate: &MemoryCandidate) -> HirnResult<AdmissionDecision> {
        // The instant is captured before waiting on the lock; the same value
        // is used both for pruning and as the recorded write time.
        let arrival = Instant::now();
        let horizon = std::time::Duration::from_secs(self.window_secs);

        let mut per_agent = self.state.lock().await;
        let history = per_agent.entry(candidate.agent_id.clone()).or_default();
        Self::prune(history, arrival, horizon);

        let used = history.len() as u64;
        if used < self.max_writes {
            // Only accepted writes consume quota.
            history.push(arrival);
            return Ok(AdmissionDecision::Accept {
                importance_override: None,
            });
        }

        let reason = format!(
            "rate limit exceeded: {current_count}/{max} writes/{window}s for agent '{agent}'",
            current_count = used,
            max = self.max_writes,
            window = self.window_secs,
            agent = candidate.agent_id.as_str(),
        );
        Ok(AdmissionDecision::Reject { reason })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use hirn_core::id::MemoryId;
    use hirn_core::metadata::Metadata;
    use hirn_core::types::{AgentId, Namespace};

    /// Builds a minimal candidate attributed to `agent`; every other field is
    /// a fixed placeholder.
    fn candidate(agent: &str) -> MemoryCandidate {
        MemoryCandidate {
            id: MemoryId::new(),
            content: "test".into(),
            entities: Vec::new(),
            embedding: None,
            agent_id: AgentId::new(agent).unwrap(),
            namespace: Namespace::shared(),
            importance: 0.5,
            surprise: 0.5,
            metadata: Metadata::default(),
        }
    }

    /// Runs one evaluation for `agent` and unwraps the decision.
    async fn submit(limiter: &RateLimiter, agent: &str) -> AdmissionDecision {
        limiter.evaluate(&candidate(agent)).await.unwrap()
    }

    /// Submits `times` writes for `agent`, asserting each one is accepted.
    async fn expect_accepts(limiter: &RateLimiter, agent: &str, times: usize) {
        for _ in 0..times {
            let decision = submit(limiter, agent).await;
            assert!(decision.is_accept());
        }
    }

    #[tokio::test]
    async fn within_limit_accepted() {
        let limiter = RateLimiter::new(5, 60);
        expect_accepts(&limiter, "agent-a", 5).await;
    }

    #[tokio::test]
    async fn exceeds_limit_rejected() {
        let limiter = RateLimiter::new(3, 60);
        expect_accepts(&limiter, "agent-a", 3).await;

        let overflow = submit(&limiter, "agent-a").await;
        assert!(overflow.is_reject());
    }

    #[tokio::test]
    async fn two_agents_independent() {
        let limiter = RateLimiter::new(2, 60);
        // Use up agent-a's quota; these decisions are not inspected.
        for _ in 0..2 {
            submit(&limiter, "agent-a").await;
        }

        let result_a = submit(&limiter, "agent-a").await;
        assert!(result_a.is_reject());

        let result_b = submit(&limiter, "agent-b").await;
        assert!(result_b.is_accept());
    }

    #[tokio::test]
    async fn window_slides() {
        // With a zero-second window every earlier timestamp is pruned before
        // the count is taken, so the limit of 2 is never reached.
        let limiter = RateLimiter::new(2, 0);
        expect_accepts(&limiter, "agent-a", 10).await;
    }

    #[tokio::test]
    async fn default_limiter() {
        let limiter = RateLimiter::with_defaults();
        expect_accepts(&limiter, "agent-a", 100).await;

        let overflow = submit(&limiter, "agent-a").await;
        assert!(overflow.is_reject());
    }
}