use std::{collections::HashMap, thread, time::Duration};
use redis::AsyncCommands;
use super::common::{key, key_gen, redis_url, unique_prefix};
use super::runtime;
use crate::common::{RateType, SuppressionFactorCacheMs};
use crate::hybrid::SyncIntervalMs;
use crate::{
HardLimitFactor, LocalRateLimiterOptions, RateGroupSizeMs, RateLimit, RateLimitDecision,
RateLimiter, RateLimiterOptions, RedisKey, RedisRateLimiterOptions, WindowSizeSeconds,
};
/// Builds a [`RateLimiter`] for these integration tests.
///
/// The local and Redis halves of the options get the same window and rate-group settings.
/// A key prefix is obtained from `unique_prefix()` and returned alongside the limiter, so
/// callers can inspect the raw Redis keys the limiter writes. The prefix is presumably unique
/// per call, so tests sharing one Redis server stay isolated.
///
/// # Panics
///
/// Panics if:
/// - `url` is not a valid Redis URL;
/// - no connection manager can be created for it;
/// - either size argument is rejected by its `TryFrom` conversion.
async fn build_limiter(
    url: &str,
    window_size_seconds: u64,
    rate_group_size_ms: u64,
) -> (std::sync::Arc<RateLimiter>, RedisKey) {
    let client = redis::Client::open(url).expect("invalid Redis URL");
    let cm = client
        .get_connection_manager()
        .await
        .expect("failed to create Redis connection manager (is the test Redis server running?)");
    let prefix = unique_prefix();
    // Each half of the options needs its own size values. They are converted per use rather
    // than cloned, so no `Clone` bound on the newtypes is assumed.
    let window = || {
        WindowSizeSeconds::try_from(window_size_seconds).expect("invalid window_size_seconds")
    };
    let rate_group = || {
        RateGroupSizeMs::try_from(rate_group_size_ms).expect("invalid rate_group_size_ms")
    };
    let options = RateLimiterOptions {
        local: LocalRateLimiterOptions {
            window_size_seconds: window(),
            rate_group_size_ms: rate_group(),
            hard_limit_factor: HardLimitFactor::default(),
            suppression_factor_cache_ms: SuppressionFactorCacheMs::default(),
        },
        redis: RedisRateLimiterOptions {
            connection_manager: cm,
            prefix: Some(prefix.clone()),
            window_size_seconds: window(),
            rate_group_size_ms: rate_group(),
            hard_limit_factor: HardLimitFactor::default(),
            suppression_factor_cache_ms: SuppressionFactorCacheMs::default(),
            sync_interval_ms: SyncIntervalMs::default(),
        },
    };
    (std::sync::Arc::new(RateLimiter::new(options)), prefix)
}
/// Resolves one of the raw Redis keys the absolute-rate limiter uses for `user_key`.
///
/// The `suffix` argument selects the key:
/// - `"t"`: total-count key.
/// - `"w"`: window-limit key.
/// - `"a"`: active-keys sorted set.
/// - `"h"`: per-bucket hash.
///
/// Any other suffix panics.
fn redis_key(prefix: &RedisKey, user_key: &RedisKey, suffix: &str) -> String {
    let generator = key_gen(prefix, RateType::Absolute);
    match suffix {
        "t" => generator.get_total_count_key(user_key),
        "w" => generator.get_window_limit_key(user_key),
        "a" => generator.get_active_keys(user_key),
        "h" => generator.get_hash_key(user_key),
        _ => panic!("unknown suffix for absolute rate type: {suffix}"),
    }
}
/// Returns the raw Redis key of the absolute limiter's active-entities sorted set.
fn active_entities_key(prefix: &RedisKey) -> String {
    let generator = key_gen(prefix, RateType::Absolute);
    generator.get_active_entities_key()
}
/// One allowed `inc` of 3 (limit 5/s, 10 s window) should leave:
/// - one bucket holding 3;
/// - a total of 3 and a window limit of 50;
/// - the key registered in the active-entities set.
#[test]
fn redis_state_after_single_allowed_inc() {
    let url = redis_url();
    runtime::block_on(async {
        let (limiter, prefix) = build_limiter(&url, 10, 1000).await;
        let user_key = key("k");
        let limit = RateLimit::try_from(5f64).unwrap();

        let d = limiter
            .redis()
            .absolute()
            .inc(&user_key, &limit, 3)
            .await
            .unwrap();
        assert!(matches!(d, RateLimitDecision::Allowed), "d: {d:?}");

        // Inspect the raw keys through an independent connection.
        let client = redis::Client::open(url.as_str()).unwrap();
        let mut raw = client.get_multiplexed_async_connection().await.unwrap();

        let total: u64 = raw.get(redis_key(&prefix, &user_key, "t")).await.unwrap();
        assert_eq!(total, 3, "total count should be 3");

        // 50 = 5 (rate) x 10 (window seconds).
        let window_limit: u64 = raw.get(redis_key(&prefix, &user_key, "w")).await.unwrap();
        assert_eq!(window_limit, 50, "window limit should be 50");

        let buckets: HashMap<String, u64> =
            raw.hgetall(redis_key(&prefix, &user_key, "h")).await.unwrap();
        assert_eq!(buckets.len(), 1, "hash should have exactly one bucket");
        let only_bucket = buckets.values().copied().next().unwrap();
        assert_eq!(only_bucket, 3, "bucket count should be 3");

        let active_buckets: u64 = raw.zcard(redis_key(&prefix, &user_key, "a")).await.unwrap();
        assert_eq!(active_buckets, 1, "active sorted set should have one member");

        let entity_score: Option<f64> = raw
            .zscore(active_entities_key(&prefix), &**user_key)
            .await
            .unwrap();
        assert!(entity_score.is_some(), "key should be in active_entities");
    });
}
/// Two increments 50 ms apart fall inside one 2000 ms rate group. They should be merged
/// into a single bucket holding their sum.
#[test]
fn redis_state_coalesces_increments_within_rate_group() {
    let url = redis_url();
    runtime::block_on(async {
        let (limiter, prefix) = build_limiter(&url, 10, 2000).await;
        let user_key = key("k");
        let limit = RateLimit::try_from(10f64).unwrap();

        limiter.redis().absolute().inc(&user_key, &limit, 2).await.unwrap();
        // NOTE(review): std::thread::sleep blocks the executor thread while waiting.
        thread::sleep(Duration::from_millis(50));
        limiter.redis().absolute().inc(&user_key, &limit, 3).await.unwrap();

        let client = redis::Client::open(url.as_str()).unwrap();
        let mut raw = client.get_multiplexed_async_connection().await.unwrap();

        let total: u64 = raw.get(redis_key(&prefix, &user_key, "t")).await.unwrap();
        assert_eq!(total, 5, "total count should be 5");

        let buckets: HashMap<String, u64> =
            raw.hgetall(redis_key(&prefix, &user_key, "h")).await.unwrap();
        assert_eq!(buckets.len(), 1, "coalesced — hash should have one bucket");
        let merged = buckets.values().copied().next().unwrap();
        assert_eq!(merged, 5, "coalesced bucket should hold 5");

        let active_buckets: u64 = raw.zcard(redis_key(&prefix, &user_key, "a")).await.unwrap();
        assert_eq!(active_buckets, 1, "active sorted set should have one member");
    });
}
/// Two increments 150 ms apart, with 100 ms rate groups, should land in two separate
/// buckets. Both must appear in the hash and in the active sorted set.
#[test]
fn redis_state_creates_distinct_buckets_across_rate_groups() {
    let url = redis_url();
    runtime::block_on(async {
        let (limiter, prefix) = build_limiter(&url, 10, 100).await;
        let user_key = key("k");
        let limit = RateLimit::try_from(10f64).unwrap();

        limiter.redis().absolute().inc(&user_key, &limit, 1).await.unwrap();
        // NOTE(review): std::thread::sleep blocks the executor thread while waiting.
        thread::sleep(Duration::from_millis(150));
        limiter.redis().absolute().inc(&user_key, &limit, 2).await.unwrap();

        let client = redis::Client::open(url.as_str()).unwrap();
        let mut raw = client.get_multiplexed_async_connection().await.unwrap();

        let total: u64 = raw.get(redis_key(&prefix, &user_key, "t")).await.unwrap();
        assert_eq!(total, 3);

        let hash: HashMap<String, u64> =
            raw.hgetall(redis_key(&prefix, &user_key, "h")).await.unwrap();
        assert_eq!(hash.len(), 2, "two buckets should exist: {hash:?}");

        let active_buckets: u64 = raw.zcard(redis_key(&prefix, &user_key, "a")).await.unwrap();
        assert_eq!(active_buckets, 2, "active sorted set should have two members");
    });
}
/// A rejected `inc` must leave the key's Redis state untouched.
///
/// With a rate of 2/s over a 1 s window, the first `inc` of 2 brings the key to capacity.
/// The following `inc` of 1 is therefore expected to be rejected. Three pieces of state are
/// snapshotted before that call and compared after it:
/// - the total counter;
/// - the per-bucket hash;
/// - the active-bucket sorted set (members and scores).
#[test]
fn redis_state_rejected_inc_does_not_mutate_state() {
    let url = redis_url();
    runtime::block_on(async {
        let (rl, prefix) = build_limiter(&url, 1, 1000).await;
        let k = key("k");
        let rate_limit = RateLimit::try_from(2f64).unwrap();
        rl.redis().absolute().inc(&k, &rate_limit, 2).await.unwrap();
        let mut conn = redis::Client::open(url.as_str())
            .unwrap()
            .get_multiplexed_async_connection()
            .await
            .unwrap();

        let total_key = redis_key(&prefix, &k, "t");
        let hash_key = redis_key(&prefix, &k, "h");
        let active_key = redis_key(&prefix, &k, "a");

        let total_before: u64 = conn.get(&total_key).await.unwrap();
        let hash_before: HashMap<String, u64> = conn.hgetall(&hash_key).await.unwrap();
        let active_before: Vec<(String, f64)> = conn
            .zrange_withscores(&active_key, 0isize, -1isize)
            .await
            .unwrap();

        let d = rl.redis().absolute().inc(&k, &rate_limit, 1).await.unwrap();
        assert!(matches!(d, RateLimitDecision::Rejected { .. }), "d: {d:?}");

        let total_after: u64 = conn.get(&total_key).await.unwrap();
        let hash_after: HashMap<String, u64> = conn.hgetall(&hash_key).await.unwrap();
        let active_after: Vec<(String, f64)> = conn
            .zrange_withscores(&active_key, 0isize, -1isize)
            .await
            .unwrap();

        assert_eq!(
            total_before, total_after,
            "total count must not change on rejection"
        );
        assert_eq!(hash_before, hash_after, "hash must not change on rejection");
        assert_eq!(
            active_before, active_after,
            "active sorted set must not change on rejection"
        );
    });
}
/// Once the 1 s window has elapsed, the next `inc` should evict the old bucket. Only the new
/// increment should remain in the total, the hash and the active sorted set.
#[test]
fn redis_state_evicts_expired_buckets_after_window() {
    let url = redis_url();
    runtime::block_on(async {
        let (limiter, prefix) = build_limiter(&url, 1, 200).await;
        let user_key = key("k");
        let limit = RateLimit::try_from(3f64).unwrap();

        limiter.redis().absolute().inc(&user_key, &limit, 3).await.unwrap();

        let client = redis::Client::open(url.as_str()).unwrap();
        let mut raw = client.get_multiplexed_async_connection().await.unwrap();

        let total_before: u64 = raw.get(redis_key(&prefix, &user_key, "t")).await.unwrap();
        assert_eq!(total_before, 3, "at capacity before expiry");

        // Wait past the 1 s window.
        // NOTE(review): std::thread::sleep blocks the executor thread while waiting.
        thread::sleep(Duration::from_millis(1100));

        let d = limiter.redis().absolute().inc(&user_key, &limit, 1).await.unwrap();
        assert!(matches!(d, RateLimitDecision::Allowed), "d: {d:?}");

        let hash: HashMap<String, u64> =
            raw.hgetall(redis_key(&prefix, &user_key, "h")).await.unwrap();
        let total_after: u64 = raw.get(redis_key(&prefix, &user_key, "t")).await.unwrap();
        let active_count: u64 = raw.zcard(redis_key(&prefix, &user_key, "a")).await.unwrap();

        assert_eq!(
            total_after, 1,
            "total count must reflect only the new increment after eviction"
        );
        assert_eq!(
            hash.len(),
            1,
            "hash must contain only the new bucket after eviction"
        );
        assert_eq!(
            active_count, 1,
            "active sorted set must contain only the new bucket after eviction"
        );
    });
}
/// `is_allowed` should report `Rejected` while the key is at capacity. Once the window has
/// elapsed, it should report `Allowed` again without any further `inc`.
#[test]
fn redis_state_is_allowed_evicts_expired_buckets() {
    let url = redis_url();
    runtime::block_on(async {
        let (limiter, prefix) = build_limiter(&url, 1, 1000).await;
        let user_key = key("k");
        let limit = RateLimit::try_from(2f64).unwrap();

        limiter.redis().absolute().inc(&user_key, &limit, 2).await.unwrap();

        let client = redis::Client::open(url.as_str()).unwrap();
        let mut raw = client.get_multiplexed_async_connection().await.unwrap();

        let total_before: u64 = raw.get(redis_key(&prefix, &user_key, "t")).await.unwrap();
        assert_eq!(total_before, 2, "total should be 2 before expiry");

        let d_before = limiter.redis().absolute().is_allowed(&user_key).await.unwrap();
        assert!(
            matches!(d_before, RateLimitDecision::Rejected { .. }),
            "should be rejected at capacity, got: {d_before:?}"
        );

        // NOTE(review): std::thread::sleep blocks the executor thread while waiting.
        thread::sleep(Duration::from_millis(1300));

        let d_after = limiter.redis().absolute().is_allowed(&user_key).await.unwrap();
        assert!(
            matches!(d_after, RateLimitDecision::Allowed),
            "should be allowed after window expiry, got: {d_after:?}"
        );
    });
}
/// Increments on two different user keys must be tracked separately. Each key's total
/// and bucket-hash sum must reflect only its own increments.
#[test]
fn redis_state_per_key_state_is_independent() {
    let url = redis_url();
    runtime::block_on(async {
        let (limiter, prefix) = build_limiter(&url, 10, 1000).await;
        let a = key("a");
        let b = key("b");
        let limit = RateLimit::try_from(5f64).unwrap();

        limiter.redis().absolute().inc(&a, &limit, 3).await.unwrap();
        limiter.redis().absolute().inc(&b, &limit, 7).await.unwrap();

        let client = redis::Client::open(url.as_str()).unwrap();
        let mut raw = client.get_multiplexed_async_connection().await.unwrap();

        let total_a: u64 = raw.get(redis_key(&prefix, &a, "t")).await.unwrap();
        let total_b: u64 = raw.get(redis_key(&prefix, &b, "t")).await.unwrap();
        assert_eq!(total_a, 3, "total for key a should be 3");
        assert_eq!(total_b, 7, "total for key b should be 7");

        let buckets_a: HashMap<String, u64> =
            raw.hgetall(redis_key(&prefix, &a, "h")).await.unwrap();
        let buckets_b: HashMap<String, u64> =
            raw.hgetall(redis_key(&prefix, &b, "h")).await.unwrap();
        let sum_a = buckets_a.values().copied().sum::<u64>();
        let sum_b = buckets_b.values().copied().sum::<u64>();
        assert_eq!(sum_a, 3, "hash sum for key a should be 3");
        assert_eq!(sum_b, 7, "hash sum for key b should be 7");
    });
}
/// After increments spread over several rate groups, the sum of all bucket counts must
/// equal the separately maintained total counter.
#[test]
fn redis_state_hash_sum_matches_total_count() {
    let url = redis_url();
    runtime::block_on(async {
        let (limiter, prefix) = build_limiter(&url, 10, 100).await;
        let user_key = key("k");
        let limit = RateLimit::try_from(100f64).unwrap();

        // Each increment after the first waits 150 ms, longer than the 100 ms rate group.
        for (step, &amount) in [5, 3, 7].iter().enumerate() {
            if step > 0 {
                // NOTE(review): std::thread::sleep blocks the executor thread while waiting.
                thread::sleep(Duration::from_millis(150));
            }
            limiter.redis().absolute().inc(&user_key, &limit, amount).await.unwrap();
        }

        let client = redis::Client::open(url.as_str()).unwrap();
        let mut raw = client.get_multiplexed_async_connection().await.unwrap();

        let total: u64 = raw.get(redis_key(&prefix, &user_key, "t")).await.unwrap();
        let buckets: HashMap<String, u64> =
            raw.hgetall(redis_key(&prefix, &user_key, "h")).await.unwrap();
        let hash_sum = buckets.values().copied().sum::<u64>();

        assert_eq!(
            hash_sum, total,
            "hash sum ({hash_sum}) must equal total count ({total})"
        );
        assert_eq!(total, 15, "total count should be 15");
    });
}
/// Three increments in three separate rate groups should yield three active-set members.
/// When read in rank order, their scores must be non-decreasing.
#[test]
fn redis_state_active_sorted_set_scores_are_ordered() {
    let url = redis_url();
    runtime::block_on(async {
        let (limiter, prefix) = build_limiter(&url, 10, 100).await;
        let user_key = key("k");
        let limit = RateLimit::try_from(100f64).unwrap();

        for round in 0..3 {
            if round > 0 {
                // 150 ms > 100 ms rate group, so each round opens a new bucket.
                // NOTE(review): std::thread::sleep blocks the executor thread while waiting.
                thread::sleep(Duration::from_millis(150));
            }
            limiter.redis().absolute().inc(&user_key, &limit, 1).await.unwrap();
        }

        let client = redis::Client::open(url.as_str()).unwrap();
        let mut raw = client.get_multiplexed_async_connection().await.unwrap();

        let members_with_scores: Vec<(String, f64)> = raw
            .zrange_withscores(redis_key(&prefix, &user_key, "a"), 0isize, -1isize)
            .await
            .unwrap();
        assert_eq!(members_with_scores.len(), 3, "should have 3 buckets");

        let scores: Vec<f64> = members_with_scores
            .iter()
            .map(|&(_, score)| score)
            .collect();
        for pair in scores.windows(2) {
            assert!(
                pair[1] >= pair[0],
                "scores must be non-decreasing: {scores:?}"
            );
        }
    });
}
/// A key should be absent from the active-entities sorted set until its first `inc`, and
/// present (with a score) afterwards.
#[test]
fn redis_state_active_entities_updated_on_inc() {
    let url = redis_url();
    runtime::block_on(async {
        let (limiter, prefix) = build_limiter(&url, 10, 1000).await;
        let entity = key("myentity");
        let limit = RateLimit::try_from(10f64).unwrap();
        let ae_key = active_entities_key(&prefix);

        let client = redis::Client::open(url.as_str()).unwrap();
        let mut raw = client.get_multiplexed_async_connection().await.unwrap();

        let score_before: Option<f64> = raw.zscore(&ae_key, &**entity).await.unwrap();
        assert!(
            score_before.is_none(),
            "key should not be in active_entities before inc"
        );

        limiter.redis().absolute().inc(&entity, &limit, 1).await.unwrap();

        let score_after: Option<f64> = raw.zscore(&ae_key, &**entity).await.unwrap();
        assert!(
            score_after.is_some(),
            "key should be in active_entities after inc"
        );
    });
}
/// After an `inc`, the window-limit key must carry an expiry. Its TTL must be positive and no
/// longer than the configured window.
#[test]
fn redis_state_window_limit_key_has_ttl() {
    let url = redis_url();
    runtime::block_on(async {
        let window_size_seconds = 5_u64;
        let max_ttl = window_size_seconds as i64;
        let (limiter, prefix) = build_limiter(&url, window_size_seconds, 1000).await;
        let user_key = key("k");
        let limit = RateLimit::try_from(10f64).unwrap();

        limiter.redis().absolute().inc(&user_key, &limit, 1).await.unwrap();

        let client = redis::Client::open(url.as_str()).unwrap();
        let mut raw = client.get_multiplexed_async_connection().await.unwrap();

        // Redis reports -1 for "no expiry" and -2 for "missing key", so > 0 rules out both.
        let ttl: i64 = raw.ttl(redis_key(&prefix, &user_key, "w")).await.unwrap();
        assert!(
            ttl > 0,
            "window limit key should have a positive TTL (EXPIRE was called), got {ttl}"
        );
        assert!(
            ttl <= max_ttl,
            "TTL should be <= window_size_seconds={window_size_seconds}, got {ttl}"
        );
    });
}