use std::{collections::HashMap, time::Duration};
use redis::AsyncCommands;
use super::runtime;
use super::common::{redis_url, unique_prefix, key, key_gen, wait_for_hybrid_sync};
use crate::common::{RateType, SuppressionFactorCacheMs};
use crate::hybrid::SyncIntervalMs;
use crate::{
HardLimitFactor, LocalRateLimiterOptions, RateGroupSizeMs, RateLimit, RateLimitDecision,
RateLimiter, RateLimiterOptions, RedisKey, RedisRateLimiterOptions, WindowSizeSeconds,
};
async fn build_limiter(
url: &str,
window_size_seconds: u64,
rate_group_size_ms: u64,
sync_interval_ms: u64,
prefix: RedisKey,
) -> std::sync::Arc<RateLimiter> {
let client = redis::Client::open(url).unwrap();
let cm = client.get_connection_manager().await.unwrap();
let options = RateLimiterOptions {
local: LocalRateLimiterOptions {
window_size_seconds: WindowSizeSeconds::try_from(window_size_seconds).unwrap(),
rate_group_size_ms: RateGroupSizeMs::try_from(rate_group_size_ms).unwrap(),
hard_limit_factor: HardLimitFactor::default(),
suppression_factor_cache_ms: SuppressionFactorCacheMs::default(),
},
redis: RedisRateLimiterOptions {
connection_manager: cm,
prefix: Some(prefix),
window_size_seconds: WindowSizeSeconds::try_from(window_size_seconds).unwrap(),
rate_group_size_ms: RateGroupSizeMs::try_from(rate_group_size_ms).unwrap(),
hard_limit_factor: HardLimitFactor::default(),
suppression_factor_cache_ms: SuppressionFactorCacheMs::default(),
sync_interval_ms: SyncIntervalMs::try_from(sync_interval_ms).unwrap(),
},
};
std::sync::Arc::new(RateLimiter::new(options))
}
fn redis_key(prefix: &RedisKey, user_key: &RedisKey, suffix: &str) -> String {
let kg = key_gen(prefix, RateType::HybridAbsolute);
match suffix {
"h" => kg.get_hash_key(user_key),
"a" => kg.get_active_keys(user_key),
"w" => kg.get_window_limit_key(user_key),
"t" => kg.get_total_count_key(user_key),
_ => panic!("unknown suffix for hybrid_absolute rate type: {suffix}"),
}
}
#[test]
fn redis_state_hybrid_absolute_no_redis_keys_before_overflow() {
let url = redis_url();
runtime::block_on(async {
let prefix = unique_prefix();
let window_size_seconds = 1_u64;
let sync_interval_ms = 2000_u64;
let rl = build_limiter(&url, window_size_seconds, 1000, sync_interval_ms, prefix.clone()).await;
let k = key("k");
let rate_limit = RateLimit::try_from(5f64).unwrap();
for _ in 0..5 {
let d = rl.hybrid().absolute().inc(&k, &rate_limit, 1).await.unwrap();
assert!(matches!(d, RateLimitDecision::Allowed), "d: {d:?}");
}
let mut conn = redis::Client::open(url.as_str())
.unwrap()
.get_multiplexed_async_connection()
.await
.unwrap();
let total: Option<u64> = conn.get(redis_key(&prefix, &k, "t")).await.unwrap();
assert!(
total.is_none(),
"total count key should not exist before the local budget overflows"
);
let hash_len: u64 = conn.hlen(redis_key(&prefix, &k, "h")).await.unwrap();
assert_eq!(hash_len, 0, "hash should be empty before overflow");
});
}
#[test]
fn redis_state_hybrid_absolute_commit_writes_total_count_after_overflow() {
let url = redis_url();
runtime::block_on(async {
let prefix = unique_prefix();
let window_size_seconds = 1_u64;
let sync_interval_ms = 25_u64;
let rl = build_limiter(&url, window_size_seconds, 1000, sync_interval_ms, prefix.clone()).await;
let k = key("k");
let rate_limit = RateLimit::try_from(5f64).unwrap();
let cap = (window_size_seconds as f64 * *rate_limit) as u64;
for _ in 0..cap {
let d = rl.hybrid().absolute().inc(&k, &rate_limit, 1).await.unwrap();
assert!(matches!(d, RateLimitDecision::Allowed), "d: {d:?}");
}
let d_overflow = rl.hybrid().absolute().inc(&k, &rate_limit, 1).await.unwrap();
assert!(
matches!(d_overflow, RateLimitDecision::Rejected { .. }),
"d_overflow: {d_overflow:?}"
);
wait_for_hybrid_sync(sync_interval_ms).await;
let mut conn = redis::Client::open(url.as_str())
.unwrap()
.get_multiplexed_async_connection()
.await
.unwrap();
let total: u64 = conn.get(redis_key(&prefix, &k, "t")).await.unwrap();
assert!(
total > 0,
"total count should be > 0 after overflow commit, got {total}"
);
assert!(
total <= cap,
"total count ({total}) should not exceed capacity ({cap})"
);
let hash: HashMap<String, u64> = conn.hgetall(redis_key(&prefix, &k, "h")).await.unwrap();
assert!(!hash.is_empty(), "hash must have at least one bucket after commit");
let active_count: u64 = conn.zcard(redis_key(&prefix, &k, "a")).await.unwrap();
assert!(
active_count > 0,
"active sorted set must be non-empty after commit"
);
});
}
#[test]
fn redis_state_hybrid_absolute_window_limit_key_is_set_after_commit() {
let url = redis_url();
runtime::block_on(async {
let prefix = unique_prefix();
let window_size_seconds = 2_u64;
let sync_interval_ms = 25_u64;
let rl = build_limiter(&url, window_size_seconds, 1000, sync_interval_ms, prefix.clone()).await;
let k = key("k");
let rate_limit = RateLimit::try_from(3f64).unwrap();
let expected_window_limit = 6_u64;
for _ in 0..expected_window_limit {
let _ = rl.hybrid().absolute().inc(&k, &rate_limit, 1).await.unwrap();
}
let _ = rl.hybrid().absolute().inc(&k, &rate_limit, 1).await.unwrap();
wait_for_hybrid_sync(sync_interval_ms).await;
let mut conn = redis::Client::open(url.as_str())
.unwrap()
.get_multiplexed_async_connection()
.await
.unwrap();
let stored_limit: u64 = conn.get(redis_key(&prefix, &k, "w")).await.unwrap();
assert_eq!(
stored_limit, expected_window_limit,
"stored window limit should equal capacity"
);
});
}
#[test]
fn redis_state_hybrid_absolute_committed_state_is_visible_to_another_instance() {
let url = redis_url();
runtime::block_on(async {
let prefix = unique_prefix();
let window_size_seconds = 1_u64;
let sync_interval_ms = 25_u64;
let rl_a = build_limiter(&url, window_size_seconds, 1000, sync_interval_ms, prefix.clone()).await;
let k = key("k");
let rate_limit = RateLimit::try_from(5f64).unwrap();
let cap = (window_size_seconds as f64 * *rate_limit) as u64;
for _ in 0..cap {
let _ = rl_a.hybrid().absolute().inc(&k, &rate_limit, 1).await.unwrap();
}
let _ = rl_a.hybrid().absolute().inc(&k, &rate_limit, 1).await.unwrap();
wait_for_hybrid_sync(sync_interval_ms).await;
let mut conn = redis::Client::open(url.as_str())
.unwrap()
.get_multiplexed_async_connection()
.await
.unwrap();
let total: u64 = conn.get(redis_key(&prefix, &k, "t")).await.unwrap();
assert!(
total > 0,
"total must be visible in Redis after A commits, got {total}"
);
});
}
#[test]
fn redis_state_hybrid_absolute_hash_sum_matches_total_count_after_commit() {
let url = redis_url();
runtime::block_on(async {
let prefix = unique_prefix();
let window_size_seconds = 1_u64;
let sync_interval_ms = 25_u64;
let rl = build_limiter(&url, window_size_seconds, 1000, sync_interval_ms, prefix.clone()).await;
let k = key("k");
let rate_limit = RateLimit::try_from(5f64).unwrap();
let cap = (window_size_seconds as f64 * *rate_limit) as u64;
for _ in 0..cap {
let _ = rl.hybrid().absolute().inc(&k, &rate_limit, 1).await.unwrap();
}
let _ = rl.hybrid().absolute().inc(&k, &rate_limit, 1).await.unwrap();
wait_for_hybrid_sync(sync_interval_ms).await;
let mut conn = redis::Client::open(url.as_str())
.unwrap()
.get_multiplexed_async_connection()
.await
.unwrap();
let total: u64 = conn.get(redis_key(&prefix, &k, "t")).await.unwrap();
let hash: HashMap<String, u64> = conn.hgetall(redis_key(&prefix, &k, "h")).await.unwrap();
let hash_sum: u64 = hash.values().sum();
assert_eq!(
hash_sum, total,
"hash sum ({hash_sum}) must equal total count ({total}) after commit"
);
});
}
#[test]
fn redis_state_hybrid_absolute_evicts_expired_buckets_on_next_commit() {
let url = redis_url();
runtime::block_on(async {
let prefix = unique_prefix();
let window_size_seconds = 1_u64;
let sync_interval_ms = 25_u64;
let rl = build_limiter(&url, window_size_seconds, 1000, sync_interval_ms, prefix.clone()).await;
let k = key("k");
let rate_limit = RateLimit::try_from(5f64).unwrap();
let cap = (window_size_seconds as f64 * *rate_limit) as u64;
for _ in 0..cap {
let _ = rl.hybrid().absolute().inc(&k, &rate_limit, 1).await.unwrap();
}
let _ = rl.hybrid().absolute().inc(&k, &rate_limit, 1).await.unwrap();
wait_for_hybrid_sync(sync_interval_ms).await;
runtime::async_sleep(Duration::from_millis(window_size_seconds * 1000 + 100)).await;
let d = rl.hybrid().absolute().inc(&k, &rate_limit, 1).await.unwrap();
assert!(matches!(d, RateLimitDecision::Allowed), "d after expiry: {d:?}");
wait_for_hybrid_sync(sync_interval_ms).await;
let mut conn = redis::Client::open(url.as_str())
.unwrap()
.get_multiplexed_async_connection()
.await
.unwrap();
let total: u64 = conn.get(redis_key(&prefix, &k, "t")).await.unwrap();
assert!(
total <= cap,
"total count ({total}) after window expiry must not exceed capacity ({cap})"
);
});
}
#[test]
fn redis_state_hybrid_absolute_different_prefixes_are_isolated() {
let url = redis_url();
runtime::block_on(async {
let prefix_a = unique_prefix();
let prefix_b = unique_prefix();
let window_size_seconds = 1_u64;
let sync_interval_ms = 25_u64;
let rl_a = build_limiter(&url, window_size_seconds, 1000, sync_interval_ms, prefix_a.clone()).await;
let rl_b = build_limiter(&url, window_size_seconds, 1000, sync_interval_ms, prefix_b.clone()).await;
let k = key("k");
let rate_limit = RateLimit::try_from(5f64).unwrap();
let cap = (window_size_seconds as f64 * *rate_limit) as u64;
for _ in 0..cap {
let _ = rl_a.hybrid().absolute().inc(&k, &rate_limit, 1).await.unwrap();
}
let _ = rl_a.hybrid().absolute().inc(&k, &rate_limit, 1).await.unwrap();
wait_for_hybrid_sync(sync_interval_ms).await;
let mut conn = redis::Client::open(url.as_str())
.unwrap()
.get_multiplexed_async_connection()
.await
.unwrap();
let total_b: Option<u64> = conn.get(redis_key(&prefix_b, &k, "t")).await.unwrap();
assert!(
total_b.is_none(),
"prefix B should not have a total count after prefix A's commit"
);
let d_b = rl_b.hybrid().absolute().inc(&k, &rate_limit, 1).await.unwrap();
assert!(matches!(d_b, RateLimitDecision::Allowed), "d_b: {d_b:?}");
});
}
#[test]
fn redis_state_hybrid_absolute_active_sorted_set_scores_are_ordered() {
let url = redis_url();
runtime::block_on(async {
let prefix = unique_prefix();
let window_size_seconds = 10_u64;
let rate_group_size_ms = 100_u64;
let sync_interval_ms = 50_u64;
let rl = build_limiter(&url, window_size_seconds, rate_group_size_ms, sync_interval_ms, prefix.clone()).await;
let k = key("k");
let rate_limit = RateLimit::try_from(10f64).unwrap();
for _ in 0..10 {
let _ = rl.hybrid().absolute().inc(&k, &rate_limit, 1).await.unwrap();
}
let _ = rl.hybrid().absolute().inc(&k, &rate_limit, 1).await.unwrap();
wait_for_hybrid_sync(sync_interval_ms).await;
runtime::async_sleep(Duration::from_millis(200)).await;
let rl2 = build_limiter(&url, window_size_seconds, rate_group_size_ms, sync_interval_ms, prefix.clone()).await;
for _ in 0..10 {
let _ = rl2.hybrid().absolute().inc(&k, &rate_limit, 1).await.unwrap();
}
let _ = rl2.hybrid().absolute().inc(&k, &rate_limit, 1).await.unwrap();
wait_for_hybrid_sync(sync_interval_ms).await;
let mut conn = redis::Client::open(url.as_str())
.unwrap()
.get_multiplexed_async_connection()
.await
.unwrap();
let members_with_scores: Vec<(String, f64)> = conn
.zrange_withscores(redis_key(&prefix, &k, "a"), 0isize, -1isize)
.await
.unwrap();
assert!(
!members_with_scores.is_empty(),
"active sorted set should not be empty after commits"
);
let scores: Vec<f64> = members_with_scores.iter().map(|(_, s)| *s).collect();
for i in 1..scores.len() {
assert!(
scores[i] >= scores[i - 1],
"scores must be non-decreasing: {scores:?}"
);
}
});
}
#[test]
fn redis_state_hybrid_absolute_cleanup_removes_all_redis_keys_for_stale_entity() {
let url = redis_url();
runtime::block_on(async {
let prefix = unique_prefix();
let window_size_seconds = 5_u64;
let sync_interval_ms = 25_u64;
let stale_after_ms = 150_u64;
let rl = build_limiter(&url, window_size_seconds, 1000, sync_interval_ms, prefix.clone()).await;
let k = key("k");
let rate_limit = RateLimit::try_from(5f64).unwrap();
let cap = (window_size_seconds as f64 * *rate_limit) as u64;
for _ in 0..cap {
let _ = rl.hybrid().absolute().inc(&k, &rate_limit, 1).await.unwrap();
}
let _ = rl.hybrid().absolute().inc(&k, &rate_limit, 1).await.unwrap();
wait_for_hybrid_sync(sync_interval_ms).await;
let active_entities_key = key_gen(&prefix, RateType::HybridAbsolute).get_active_entities_key();
let mut conn = redis::Client::open(url.as_str())
.unwrap()
.get_multiplexed_async_connection()
.await
.unwrap();
let kg = key_gen(&prefix, RateType::HybridAbsolute);
for entity_key in kg.get_all_entity_keys(&k) {
if entity_key == kg.get_hash_key(&k)
|| entity_key == kg.get_active_keys(&k)
|| entity_key == kg.get_window_limit_key(&k)
|| entity_key == kg.get_total_count_key(&k)
{
let exists: bool = conn.exists(&entity_key).await.unwrap();
assert!(exists, "key {entity_key} must exist before cleanup");
}
}
let score: Option<f64> = conn.zscore(&active_entities_key, k.as_str()).await.unwrap();
assert!(score.is_some(), "entity must be in active_entities before cleanup");
runtime::async_sleep(Duration::from_millis(stale_after_ms + 50)).await;
rl.hybrid().absolute().cleanup(stale_after_ms).await.unwrap();
for entity_key in kg.get_all_entity_keys(&k) {
let exists: bool = conn.exists(&entity_key).await.unwrap();
assert!(!exists, "key {entity_key} must be deleted after cleanup");
}
let score_after: Option<f64> = conn.zscore(&active_entities_key, k.as_str()).await.unwrap();
assert!(score_after.is_none(), "entity must be removed from active_entities after cleanup");
});
}
#[test]
fn redis_state_hybrid_absolute_cleanup_does_not_remove_active_entity() {
let url = redis_url();
runtime::block_on(async {
let prefix = unique_prefix();
let window_size_seconds = 5_u64;
let sync_interval_ms = 25_u64;
let stale_after_ms = 5_000_u64;
let rl = build_limiter(&url, window_size_seconds, 1000, sync_interval_ms, prefix.clone()).await;
let k = key("k");
let rate_limit = RateLimit::try_from(5f64).unwrap();
let cap = (window_size_seconds as f64 * *rate_limit) as u64;
for _ in 0..cap {
let _ = rl.hybrid().absolute().inc(&k, &rate_limit, 1).await.unwrap();
}
let _ = rl.hybrid().absolute().inc(&k, &rate_limit, 1).await.unwrap();
wait_for_hybrid_sync(sync_interval_ms).await;
rl.hybrid().absolute().cleanup(stale_after_ms).await.unwrap();
let active_entities_key = key_gen(&prefix, RateType::HybridAbsolute).get_active_entities_key();
let mut conn = redis::Client::open(url.as_str())
.unwrap()
.get_multiplexed_async_connection()
.await
.unwrap();
let t_exists: bool = conn.exists(redis_key(&prefix, &k, "t")).await.unwrap();
assert!(t_exists, "total count key must still exist for active entity");
let score: Option<f64> = conn.zscore(&active_entities_key, k.as_str()).await.unwrap();
assert!(score.is_some(), "active entity must remain in active_entities after cleanup");
});
}
#[test]
fn redis_state_hybrid_absolute_cleanup_allows_fresh_requests_after_cleanup() {
let url = redis_url();
runtime::block_on(async {
let prefix = unique_prefix();
let window_size_seconds = 5_u64;
let sync_interval_ms = 25_u64;
let stale_after_ms = 150_u64;
let rl = build_limiter(&url, window_size_seconds, 1000, sync_interval_ms, prefix.clone()).await;
let k = key("k");
let rate_limit = RateLimit::try_from(2f64).unwrap();
let cap = (window_size_seconds as f64 * *rate_limit) as u64;
for _ in 0..cap {
let _ = rl.hybrid().absolute().inc(&k, &rate_limit, 1).await.unwrap();
}
let rejected = rl.hybrid().absolute().inc(&k, &rate_limit, 1).await.unwrap();
assert!(
matches!(rejected, RateLimitDecision::Rejected { .. }),
"expected Rejected after overflow, got {rejected:?}"
);
wait_for_hybrid_sync(sync_interval_ms).await;
runtime::async_sleep(Duration::from_millis(stale_after_ms + 50)).await;
rl.hybrid().absolute().cleanup(stale_after_ms).await.unwrap();
let mut conn = redis::Client::open(url.as_str())
.unwrap()
.get_multiplexed_async_connection()
.await
.unwrap();
let t_exists: bool = conn.exists(redis_key(&prefix, &k, "t")).await.unwrap();
assert!(!t_exists, "total count key must be deleted after cleanup");
let decision = rl.hybrid().absolute().inc(&k, &rate_limit, 1).await.unwrap();
assert!(
matches!(decision, RateLimitDecision::Allowed),
"expected Allowed after cleanup but got {decision:?}"
);
});
}
#[test]
fn redis_state_hybrid_absolute_cleanup_multiple_entities_mixed() {
let url = redis_url();
runtime::block_on(async {
let prefix = unique_prefix();
let window_size_seconds = 5_u64;
let sync_interval_ms = 25_u64;
let stale_after_ms = 150_u64;
let rl = build_limiter(&url, window_size_seconds, 1000, sync_interval_ms, prefix.clone()).await;
let stale = key("stale_user");
let active = key("active_user");
let rate_limit = RateLimit::try_from(2f64).unwrap();
let cap = (window_size_seconds as f64 * *rate_limit) as u64;
for _ in 0..cap {
let _ = rl.hybrid().absolute().inc(&stale, &rate_limit, 1).await.unwrap();
}
let _ = rl.hybrid().absolute().inc(&stale, &rate_limit, 1).await.unwrap();
wait_for_hybrid_sync(sync_interval_ms).await;
runtime::async_sleep(Duration::from_millis(stale_after_ms + 50)).await;
let rl2 = build_limiter(&url, window_size_seconds, 1000, sync_interval_ms, prefix.clone()).await;
for _ in 0..cap {
let _ = rl2.hybrid().absolute().inc(&active, &rate_limit, 1).await.unwrap();
}
let _ = rl2.hybrid().absolute().inc(&active, &rate_limit, 1).await.unwrap();
wait_for_hybrid_sync(sync_interval_ms).await;
rl2.hybrid().absolute().cleanup(stale_after_ms).await.unwrap();
let active_entities_key = key_gen(&prefix, RateType::HybridAbsolute).get_active_entities_key();
let mut conn = redis::Client::open(url.as_str())
.unwrap()
.get_multiplexed_async_connection()
.await
.unwrap();
let stale_t: bool = conn.exists(redis_key(&prefix, &stale, "t")).await.unwrap();
assert!(!stale_t, "stale_user total count key must be deleted");
let stale_score: Option<f64> = conn.zscore(&active_entities_key, stale.as_str()).await.unwrap();
assert!(stale_score.is_none(), "stale_user must be removed from active_entities");
let active_t: bool = conn.exists(redis_key(&prefix, &active, "t")).await.unwrap();
assert!(active_t, "active_user total count key must still exist");
let active_score: Option<f64> = conn.zscore(&active_entities_key, active.as_str()).await.unwrap();
assert!(active_score.is_some(), "active_user must remain in active_entities");
});
}
#[test]
fn redis_state_hybrid_absolute_redis_absolute_keys_do_not_contaminate_hybrid_keyspace() {
let url = redis_url();
runtime::block_on(async {
let prefix = unique_prefix();
let window_size_seconds = 1_u64;
let sync_interval_ms = 2000_u64;
let rl = build_limiter(&url, window_size_seconds, 1000, sync_interval_ms, prefix.clone()).await;
let k = key("k");
let rate_limit = RateLimit::try_from(5f64).unwrap();
for _ in 0..5 {
let _ = rl.redis().absolute().inc(&k, &rate_limit, 1).await.unwrap();
}
let mut conn = redis::Client::open(url.as_str())
.unwrap()
.get_multiplexed_async_connection()
.await
.unwrap();
let hybrid_total: Option<u64> = conn.get(redis_key(&prefix, &k, "t")).await.unwrap();
assert!(
hybrid_total.is_none(),
"hybrid_absolute keyspace must not be contaminated by redis absolute writes"
);
});
}