use std::collections::{HashMap, VecDeque};
use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};
use std::time::{Duration, Instant};
use http::{Request, Response};
use tower::Service;
use super::store::SlidingWindowStore;
use crate::http::{Body, BoxError, HttpService};
/// Maps a request to the key the store counts it under; requests that produce the
/// same string are passed to the store with the same key and so share one budget.
type KeyFn = Arc<dyn Fn(&Request<Body>) -> String + Send + Sync>;
/// Default soft cap on the number of keys the in-memory store tracks.
const DEFAULT_MAX_KEYS: usize = 10_000;
/// Default idle time after which a key may be forgotten (never less than the window).
const DEFAULT_IDLE_TTL: Duration = Duration::from_secs(600);
/// Minimum spacing between full sweeps for idle keys.
const CLEANUP_INTERVAL: Duration = Duration::from_secs(30);

/// Mutable state of the in-memory sliding-window store.
struct SharedState {
    /// Admission history per key.
    windows: HashMap<String, WindowState>,
    /// Requests allowed per `window` for each key (zero is treated as one).
    count: u64,
    /// Length of the sliding window.
    window: Duration,
    /// Soft cap on `windows.len()`; keys that are still active are never evicted to honour it.
    max_keys: usize,
    /// Idle time after which a key may be dropped (raised to `window` when shorter).
    idle_ttl: Duration,
    /// Earliest instant at which `maybe_cleanup` sweeps again.
    next_cleanup: Instant,
}

/// Admission history for a single key.
struct WindowState {
    /// Start instants granted to requests, in non-decreasing order. Entries lie in
    /// the future when a request was given a delayed slot.
    timestamps: VecDeque<Instant>,
    /// Latest start instant granted to this key (possibly in the future); used to
    /// decide when the key has gone idle.
    last_seen: Instant,
}

impl SharedState {
    /// Idle period after which a key can be forgotten.
    ///
    /// Never shorter than `window`, so a key is not dropped while its recent
    /// admissions still count against the limit.
    fn effective_ttl(&self) -> Duration {
        self.idle_ttl.max(self.window)
    }

    /// Drops every key idle for longer than [`Self::effective_ttl`], at most once
    /// per [`CLEANUP_INTERVAL`].
    fn maybe_cleanup(&mut self, now: Instant) {
        if now < self.next_cleanup {
            return;
        }
        let ttl = self.effective_ttl();
        self.windows
            .retain(|_, state| now.saturating_duration_since(state.last_seen) <= ttl);
        self.next_cleanup = now + CLEANUP_INTERVAL;
    }

    /// Makes room for a new `key` when the map already holds `max_keys` entries.
    ///
    /// Only the least recently seen key that is already past the TTL is removed.
    /// If no key is idle, nothing is evicted and the map is allowed to exceed
    /// `max_keys` rather than discard state that is still limiting traffic.
    fn evict_if_needed(&mut self, key: &str, now: Instant) {
        if self.windows.contains_key(key) || self.windows.len() < self.max_keys {
            return;
        }
        let ttl = self.effective_ttl();
        if let Some(oldest_key) = self
            .windows
            .iter()
            .filter(|(_, state)| now.saturating_duration_since(state.last_seen) > ttl)
            .min_by_key(|(_, state)| state.last_seen)
            .map(|(k, _)| k.clone())
        {
            self.windows.remove(&oldest_key);
        }
    }

    /// Records a request for `key` and returns how long it must wait before starting.
    ///
    /// Returns `None` when the request may start immediately. Otherwise the request
    /// is given the earliest start instant at which no more than `count` recorded
    /// starts (including earlier delayed ones) fall within one `window`. That slot is
    /// recorded and the wait until it is returned.
    ///
    /// A `count` of zero is treated as one: a delay-only limiter cannot refuse
    /// requests outright, and a zero budget would otherwise index an empty queue.
    fn take(&mut self, key: &str) -> Option<Duration> {
        let now = Instant::now();
        self.maybe_cleanup(now);
        self.evict_if_needed(key, now);

        let limit = usize::try_from(self.count.max(1)).unwrap_or(usize::MAX);
        let window = self.window;
        // `Instant - Duration` panics if the result would precede the platform clock's
        // origin (e.g. shortly after boot with a long window). In that case nothing can
        // have expired yet, so pruning is skipped.
        let cutoff = now.checked_sub(window);

        let state = self
            .windows
            .entry(key.to_string())
            .or_insert_with(|| WindowState {
                timestamps: VecDeque::new(),
                last_seen: now,
            });
        if let Some(cutoff) = cutoff {
            while state.timestamps.front().is_some_and(|&t| t <= cutoff) {
                state.timestamps.pop_front();
            }
        }

        let len = state.timestamps.len();
        let slot = if len < limit {
            now
        } else {
            // Slots are handed out in non-decreasing order, so the new request may start
            // once the `limit`-th most recent slot leaves the window. Using the oldest
            // entry instead would give every queued request the same slot and release
            // whole bursts together.
            (state.timestamps[len - limit] + window).max(now)
        };
        state.timestamps.push_back(slot);
        // Track the latest granted slot so cleanup cannot drop a key whose delayed
        // requests are still pending.
        state.last_seen = state.last_seen.max(slot);
        (slot > now).then(|| slot - now)
    }
}
/// In-process [`SlidingWindowStore`] that keeps a timestamp queue per key.
///
/// The state lives behind an `Arc<Mutex<_>>`, so every clone of this store
/// (and therefore every clone of a layer or service holding it) shares one budget.
#[derive(Clone)]
pub struct InMemorySlidingWindowStore {
    state: Arc<Mutex<SharedState>>,
}

impl InMemorySlidingWindowStore {
    /// Creates a store allowing `count` requests per `window` for each key, using
    /// the default key cap and idle TTL.
    pub(crate) fn new(count: u64, window: Duration) -> Self {
        let state = SharedState {
            windows: HashMap::new(),
            count,
            window,
            max_keys: DEFAULT_MAX_KEYS,
            idle_ttl: DEFAULT_IDLE_TTL,
            next_cleanup: Instant::now() + CLEANUP_INTERVAL,
        };
        Self {
            state: Arc::new(Mutex::new(state)),
        }
    }

    /// Locks the shared state, recovering it if a previous holder panicked.
    ///
    /// Without this, one panic while the lock was held would poison the mutex and
    /// make every later request on every clone panic as well. The state is a
    /// rate-limiting cache whose collections remain structurally valid, so carrying
    /// on with it is preferable to failing permanently.
    fn lock(&self) -> std::sync::MutexGuard<'_, SharedState> {
        self.state
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Sets the soft cap on tracked keys; a value of zero is raised to one.
    pub(crate) fn set_max_keys(&self, max: usize) {
        self.lock().max_keys = max.max(1);
    }

    /// Sets how long a key may stay idle before it becomes eligible for removal.
    pub(crate) fn set_idle_ttl(&self, ttl: Duration) {
        self.lock().idle_ttl = ttl;
    }
}

impl SlidingWindowStore for InMemorySlidingWindowStore {
    fn take(&self, key: &str) -> impl Future<Output = Option<Duration>> + Send {
        // The guard is a temporary dropped at the end of this statement, so the lock
        // is never held across an await point; the answer is already computed.
        let result = self.lock().take(key);
        std::future::ready(result)
    }
}
/// Tower layer that limits request rate per key by delaying requests.
///
/// Each request is mapped to a key by `key_fn`. The store decides whether a request
/// with that key must wait, and for how long, before it is forwarded.
pub struct SlidingWindow<S: SlidingWindowStore = InMemorySlidingWindowStore> {
    store: S,
    key_fn: KeyFn,
}

impl<S: SlidingWindowStore> Clone for SlidingWindow<S> {
    /// Clones the store handle and shares the same key function.
    fn clone(&self) -> Self {
        let store = self.store.clone();
        let key_fn = Arc::clone(&self.key_fn);
        Self { store, key_fn }
    }
}

impl<S: SlidingWindowStore> SlidingWindow<S> {
    /// Builds a layer around an existing store (e.g. a shared or remote backend)
    /// and a function deriving each request's key.
    pub fn with_store(
        store: S,
        key_fn: impl Fn(&Request<Body>) -> String + Send + Sync + 'static,
    ) -> Self {
        let key_fn: KeyFn = Arc::new(key_fn);
        Self { store, key_fn }
    }
}

impl SlidingWindow {
    /// Builds a layer with a fresh in-memory store configured for `count`
    /// requests per `window`, keyed by `key_fn`.
    pub fn keyed(
        count: u64,
        window: Duration,
        key_fn: impl Fn(&Request<Body>) -> String + Send + Sync + 'static,
    ) -> Self {
        Self::with_store(InMemorySlidingWindowStore::new(count, window), key_fn)
    }

    /// One budget for all traffic: every request maps to the empty key.
    pub fn global(count: u64, window: Duration) -> Self {
        Self::keyed(count, window, |_| String::new())
    }

    /// One budget per target host, as derived by `extract_host`.
    pub fn per_host(count: u64, window: Duration) -> Self {
        Self::keyed(count, window, extract_host)
    }

    /// Sets the in-memory store's soft cap on tracked keys. The store is shared, so
    /// this also affects any clone made earlier.
    pub fn max_keys(self, max: usize) -> Self {
        self.store.set_max_keys(max);
        self
    }

    /// Sets how long a key may stay idle before the in-memory store may forget it.
    /// The store is shared, so this also affects any clone made earlier.
    pub fn idle_ttl(self, ttl: Duration) -> Self {
        self.store.set_idle_ttl(ttl);
        self
    }
}
/// Derives a rate-limit key from the request's target host, without any port.
///
/// The URI's host is preferred; otherwise the `Host` header is used. Requests with
/// neither, or with a `Host` header that `to_str` rejects, share the key `"unknown"`.
/// Bracketed IPv6 literals such as `[::1]` or `[::1]:8080` keep their full address
/// instead of being cut at the first `:`.
fn extract_host(req: &Request<Body>) -> String {
    req.uri()
        .host()
        .or_else(|| req.headers().get(http::header::HOST)?.to_str().ok())
        .map(strip_port)
        .unwrap_or("unknown")
        .to_string()
}

/// Removes a trailing `:port` from `host`, leaving bracketed IPv6 literals intact.
fn strip_port(host: &str) -> &str {
    if host.starts_with('[') {
        // IPv6 literal: the address is everything up to and including `]`.
        match host.find(']') {
            Some(end) => &host[..=end],
            None => host,
        }
    } else {
        host.split_once(':').map_or(host, |(name, _port)| name)
    }
}
impl<S: SlidingWindowStore> tower::Layer<HttpService> for SlidingWindow<S> {
type Service = SlidingWindowService<S>;
fn layer(&self, inner: HttpService) -> Self::Service {
SlidingWindowService {
inner,
store: self.store.clone(),
key_fn: self.key_fn.clone(),
}
}
}
/// Service produced by the [`SlidingWindow`] layer.
///
/// Each request is keyed with `key_fn`. When the store returns a delay, the
/// response future sleeps for that long before driving the inner response.
pub struct SlidingWindowService<S: SlidingWindowStore = InMemorySlidingWindowStore> {
    inner: HttpService,
    store: S,
    key_fn: KeyFn,
}

impl<S: SlidingWindowStore> Service<Request<Body>> for SlidingWindowService<S> {
    type Response = Response<Body>;
    type Error = BoxError;
    type Future = Pin<Box<dyn Future<Output = Result<Response<Body>, BoxError>> + Send>>;

    /// Readiness is delegated to the wrapped service; the limiter applies back-pressure
    /// by delaying futures, not through `poll_ready`.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, req: Request<Body>) -> Self::Future {
        // The key must be derived before `req` is moved into the inner service.
        let key = (self.key_fn)(&req);
        let store = self.store.clone();
        // NOTE(review): the inner `call` happens here, before the delay below. That only
        // postpones the request if the inner service defers its work until the returned
        // future is polled — confirm this holds for `HttpService`.
        let response = self.inner.call(req);
        Box::pin(async move {
            // The store is consulted when this future is first polled, not at `call` time.
            let wait = store.take(&key).await;
            if let Some(pause) = wait {
                tokio::time::sleep(pause).await;
            }
            response.await
        })
    }
}
#[cfg(feature = "redis")]
mod redis_impl {
    use std::time::Duration;

    use super::super::store::SlidingWindowStore;
    use super::InMemorySlidingWindowStore;
    use crate::redis::RedisConnection;

    /// Atomically prunes, checks and records one request in a sorted set whose
    /// scores are start times in milliseconds (taken from the Redis server clock).
    ///
    /// KEYS[1] is the set; ARGV is `[count, window_ms, unique member]`. Returns the
    /// number of milliseconds the caller must wait (0 means start now). Delayed
    /// requests are recorded at their future slot, so later callers queue behind them.
    const SLIDING_WINDOW_LUA: &str = r#"
local key = KEYS[1]
local count = tonumber(ARGV[1])
-- A zero budget cannot be expressed by delaying; treat it as one request per window.
if count < 1 then count = 1 end
local window_ms = tonumber(ARGV[2])
local member = ARGV[3]
local t = redis.call('TIME')
local now_ms = tonumber(t[1]) * 1000 + math.floor(tonumber(t[2]) / 1000)
local cutoff = now_ms - window_ms
redis.call('ZREMRANGEBYSCORE', key, '-inf', cutoff)
local current = redis.call('ZCARD', key)
if current < count then
  redis.call('ZADD', key, now_ms, member)
  redis.call('PEXPIRE', key, window_ms + 1000)
  return 0
end
-- Slots are handed out in non-decreasing order, so the new request may start once the
-- count-th most recent slot leaves the window. Keying off the oldest entry instead
-- would give every queued request the same slot and release whole bursts together.
local pivot = redis.call('ZRANGE', key, current - count, current - count, 'WITHSCORES')
if #pivot < 2 then
  return 0
end
local slot_ms = tonumber(pivot[2]) + window_ms
if slot_ms < now_ms then slot_ms = now_ms end
local delay_ms = slot_ms - now_ms
redis.call('ZADD', key, slot_ms, member)
redis.call('PEXPIRE', key, window_ms + delay_ms + 1000)
return delay_ms
"#;

    /// [`SlidingWindowStore`] backed by a Redis sorted set, falling back to an
    /// in-memory store whenever Redis cannot be reached or the script fails.
    #[derive(Clone)]
    pub struct RedisSlidingWindowStore {
        conn: RedisConnection,
        fallback: InMemorySlidingWindowStore,
        count: u64,
        window: Duration,
        namespace: String,
    }

    impl RedisSlidingWindowStore {
        /// Creates a store allowing `count` requests per `window` per key, under the
        /// default `sliding_window` namespace.
        pub fn new(conn: RedisConnection, count: u64, window: Duration) -> Self {
            Self {
                conn,
                fallback: InMemorySlidingWindowStore::new(count, window),
                count,
                window,
                namespace: "sliding_window".to_string(),
            }
        }

        /// Places this store's keys under `sliding_window:{id}` so that separate
        /// limiters do not share counters.
        pub fn scope(mut self, id: &str) -> Self {
            self.namespace = format!("sliding_window:{id}");
            self
        }
    }

    impl SlidingWindowStore for RedisSlidingWindowStore {
        fn take(&self, key: &str) -> impl std::future::Future<Output = Option<Duration>> + Send {
            let redis_key = self.conn.prefixed_key(&self.namespace, key);
            let conn = self.conn.clone();
            let count = self.count;
            // Saturate rather than silently truncate absurdly long windows.
            let window_ms = u64::try_from(self.window.as_millis()).unwrap_or(u64::MAX);
            let fallback = self.fallback.clone();
            let key = key.to_string();
            async move {
                let mgr = match conn.get_connection().await {
                    Ok(mgr) => mgr,
                    Err(e) => {
                        tracing::warn!(error = %e, "Redis sliding window connect failed, using in-memory fallback");
                        return fallback.take(&key).await;
                    }
                };
                // Random member so concurrent requests with equal scores stay distinct entries.
                let member = format!("{:x}:{:x}", rand::random::<u64>(), rand::random::<u64>());
                let result: Result<i64, _> = ::redis::Script::new(SLIDING_WINDOW_LUA)
                    .key(&redis_key)
                    .arg(count)
                    .arg(window_ms)
                    .arg(&member)
                    .invoke_async(&mut mgr.clone())
                    .await;
                match result {
                    Ok(delay_ms) if delay_ms > 0 => Some(Duration::from_millis(delay_ms as u64)),
                    Ok(_) => None,
                    Err(e) => {
                        tracing::warn!(error = %e, "Redis sliding window failed, using in-memory fallback");
                        fallback.take(&key).await
                    }
                }
            }
        }
    }
}
#[cfg(feature = "redis")]
pub use redis_impl::RedisSlidingWindowStore;
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a single-request-per-window `SharedState` with the given limits.
    fn state_with(
        window: Duration,
        max_keys: usize,
        idle_ttl: Duration,
        next_cleanup: Instant,
    ) -> SharedState {
        SharedState {
            windows: HashMap::new(),
            count: 1,
            window,
            max_keys,
            idle_ttl,
            next_cleanup,
        }
    }

    #[test]
    fn shared_state_preserves_active_keys_when_over_capacity() {
        let mut state = state_with(
            Duration::from_secs(1),
            2,
            Duration::from_secs(60),
            Instant::now() + CLEANUP_INTERVAL,
        );
        for key in ["a", "b", "c"] {
            let _ = state.take(key);
        }
        for key in ["a", "b", "c"] {
            assert!(state.windows.contains_key(key));
        }
    }

    #[test]
    fn shared_state_evicts_idle_keys() {
        let mut state = state_with(
            Duration::from_secs(1),
            10,
            Duration::from_millis(1),
            Instant::now(),
        );
        let _ = state.take("a");
        state
            .windows
            .values_mut()
            .for_each(|entry| entry.last_seen = Instant::now() - Duration::from_secs(5));
        state.next_cleanup = Instant::now();
        let _ = state.take("b");
        assert!(!state.windows.contains_key("a"));
    }

    #[test]
    fn shared_state_preserves_until_window_expires() {
        let mut state = state_with(
            Duration::from_secs(10),
            10,
            Duration::from_millis(1),
            Instant::now(),
        );
        let _ = state.take("a");
        let entry = state.windows.get_mut("a").unwrap();
        entry.last_seen = Instant::now() - Duration::from_secs(5);
        state.next_cleanup = Instant::now();
        let _ = state.take("b");
        assert!(state.windows.contains_key("a"));
    }

    #[test]
    fn shared_state_does_not_evict_active_key_at_capacity() {
        let mut state = state_with(
            Duration::from_secs(10),
            1,
            Duration::from_secs(600),
            Instant::now() + CLEANUP_INTERVAL,
        );
        for key in ["a", "b"] {
            let _ = state.take(key);
        }
        assert!(["a", "b"].iter().all(|key| state.windows.contains_key(*key)));
    }
}