use std::{collections::HashMap, thread, time::Duration};
use redis::AsyncCommands;
use super::runtime;
use super::common::{redis_url, unique_prefix, key, key_gen};
use crate::common::{RateType, SuppressionFactorCacheMs};
use crate::hybrid::SyncIntervalMs;
use crate::{
HardLimitFactor, LocalRateLimiterOptions, RateGroupSizeMs, RateLimit, RateLimitDecision,
RateLimiter, RateLimiterOptions, RedisKey, RedisRateLimiterOptions, WindowSizeSeconds,
};
/// Builds a shared [`RateLimiter`] talking to the Redis server at `url`.
///
/// The local and Redis halves receive identical window, bucket, hard-limit and
/// suppression-cache settings. The Redis half additionally gets a freshly
/// generated key prefix (from `unique_prefix`) and the default sync interval.
/// The prefix is returned next to the limiter so tests can address the exact
/// keys the limiter writes.
///
/// # Panics
///
/// Panics if the Redis client or connection manager cannot be created, or if
/// any numeric setting is rejected by its `try_from` conversion.
async fn build_limiter(
    url: &str,
    window_size_seconds: u64,
    rate_group_size_ms: u64,
    hard_limit_factor: f64,
    suppression_factor_cache_ms: u64,
) -> (std::sync::Arc<RateLimiter>, RedisKey) {
    let redis_client = redis::Client::open(url).unwrap();
    let connection_manager = redis_client.get_connection_manager().await.unwrap();
    let prefix = unique_prefix();

    let local_options = LocalRateLimiterOptions {
        window_size_seconds: WindowSizeSeconds::try_from(window_size_seconds).unwrap(),
        rate_group_size_ms: RateGroupSizeMs::try_from(rate_group_size_ms).unwrap(),
        hard_limit_factor: HardLimitFactor::try_from(hard_limit_factor).unwrap(),
        suppression_factor_cache_ms: SuppressionFactorCacheMs::try_from(
            suppression_factor_cache_ms,
        )
        .unwrap(),
    };

    let redis_options = RedisRateLimiterOptions {
        connection_manager,
        prefix: Some(prefix.clone()),
        window_size_seconds: WindowSizeSeconds::try_from(window_size_seconds).unwrap(),
        rate_group_size_ms: RateGroupSizeMs::try_from(rate_group_size_ms).unwrap(),
        hard_limit_factor: HardLimitFactor::try_from(hard_limit_factor).unwrap(),
        suppression_factor_cache_ms: SuppressionFactorCacheMs::try_from(
            suppression_factor_cache_ms,
        )
        .unwrap(),
        sync_interval_ms: SyncIntervalMs::default(),
    };

    let limiter = RateLimiter::new(RateLimiterOptions {
        local: local_options,
        redis: redis_options,
    });
    (std::sync::Arc::new(limiter), prefix)
}
/// Returns the Redis key of kind `suffix` that the suppressed rate type uses
/// for `user_key` under `prefix`.
///
/// Recognised suffix tags:
/// - `t` / `d`: total count / total declined counters
/// - `h` / `hd`: bucket hash / declined bucket hash
/// - `a`: active keys
/// - `w`: window limit
/// - `sf`: suppression factor
///
/// # Panics
///
/// Panics on any other suffix.
fn redis_key(prefix: &RedisKey, user_key: &RedisKey, suffix: &str) -> String {
    let generator = key_gen(prefix, RateType::Suppressed);
    match suffix {
        "t" => generator.get_total_count_key(user_key),
        "d" => generator.get_total_declined_key(user_key),
        "h" => generator.get_hash_key(user_key),
        "hd" => generator.get_hash_declined_key(user_key),
        "a" => generator.get_active_keys(user_key),
        "w" => generator.get_window_limit_key(user_key),
        "sf" => generator.get_suppression_factor_key(user_key),
        _ => panic!("unknown suffix for suppressed rate type: {suffix}"),
    }
}
/// Returns the prefix-wide key that the suppressed rate type uses to track
/// active entities. The tests read it as a sorted set (via `zscore`).
fn active_entities_key(prefix: &RedisKey) -> String {
    let generator = key_gen(prefix, RateType::Suppressed);
    generator.get_active_entities_key()
}
/// An allowed `inc` on the suppressed path must record the increment in the
/// total counter and in the bucket hash, while leaving all declined state empty.
#[test]
fn redis_state_suppressed_allowed_inc_sets_total_count_and_zero_declined() {
    let url = redis_url();
    runtime::block_on(async {
        let (rl, prefix) = build_limiter(&url, 10, 1000, 2.0, 100).await;
        let k = key("k");
        let rate_limit = RateLimit::try_from(1f64).unwrap();
        let d = rl
            .redis()
            .suppressed()
            .inc(&k, &rate_limit, 3)
            .await
            .unwrap();
        assert!(matches!(d, RateLimitDecision::Allowed), "d: {d:?}");
        let mut conn = redis::Client::open(url.as_str())
            .unwrap()
            .get_multiplexed_async_connection()
            .await
            .unwrap();
        let total: u64 = conn.get(redis_key(&prefix, &k, "t")).await.unwrap();
        assert_eq!(total, 3, "total count should be 3");
        // The declined counter may legitimately be absent (nil reply) when
        // nothing was declined, so read it as an Option. Connection or type
        // errors are no longer coerced to 0 and now fail the test.
        let declined: Option<u64> = conn.get(redis_key(&prefix, &k, "d")).await.unwrap();
        assert_eq!(declined.unwrap_or(0), 0, "declined count should be 0");
        let hash: HashMap<String, u64> =
            conn.hgetall(redis_key(&prefix, &k, "h")).await.unwrap();
        let hash_sum: u64 = hash.values().sum();
        assert_eq!(hash_sum, 3, "hash sum should equal the increment");
        let hash_d: HashMap<String, u64> =
            conn.hgetall(redis_key(&prefix, &k, "hd")).await.unwrap();
        assert!(hash_d.is_empty(), "declined hash should be empty");
    });
}
/// Drives a key far enough that a follow-up call is fully suppressed, then
/// checks that the denial shows up in the declined counter and the declined
/// hash, and that the total counter is never smaller than the declined one.
#[test]
fn redis_state_suppressed_denied_calls_increment_both_total_and_declined() {
    let url = redis_url();
    runtime::block_on(async {
        let (limiter, prefix) = build_limiter(&url, 10, 1000, 10.0, 1).await;
        let user_key = key("k");
        let one_per_unit = RateLimit::try_from(1f64).unwrap();

        // One large increment. The next call is asserted below to be denied
        // with a suppression factor of exactly 1.0.
        let _ = limiter
            .redis()
            .suppressed()
            .inc(&user_key, &one_per_unit, 100)
            .await
            .unwrap();

        // Pause for longer than the 1 ms suppression-factor cache configured above.
        runtime::async_sleep(Duration::from_millis(10)).await;

        let d = limiter
            .redis()
            .suppressed()
            .inc(&user_key, &one_per_unit, 1)
            .await
            .unwrap();
        let fully_denied = matches!(
            d,
            RateLimitDecision::Suppressed {
                suppression_factor,
                is_allowed: false
            } if suppression_factor == 1.0
        );
        assert!(fully_denied, "d: {d:?}");

        let mut inspector = redis::Client::open(url.as_str())
            .unwrap()
            .get_multiplexed_async_connection()
            .await
            .unwrap();

        let declined: u64 = inspector
            .get(redis_key(&prefix, &user_key, "d"))
            .await
            .unwrap();
        assert!(declined > 0, "declined count should be > 0 after a denial");

        let declined_buckets: HashMap<String, u64> = inspector
            .hgetall(redis_key(&prefix, &user_key, "hd"))
            .await
            .unwrap();
        assert!(
            !declined_buckets.is_empty(),
            "declined hash must have entries after a denial"
        );

        let total: u64 = inspector
            .get(redis_key(&prefix, &user_key, "t"))
            .await
            .unwrap();
        assert!(
            total >= declined,
            "total ({total}) must be >= declined ({declined})"
        );
    });
}
/// Pushing a key past its soft limit must leave a cached suppression factor
/// in Redis that parses as a float in (0, 1].
#[test]
fn redis_state_suppressed_sf_cache_key_is_set_in_suppression_zone() {
    let url = redis_url();
    runtime::block_on(async {
        // suppression_factor_cache_ms = 1: the cached factor only lives ~1 ms.
        let (rl, prefix) = build_limiter(&url, 10, 100, 2.0, 1).await;
        let k = key("k");
        let rate_limit = RateLimit::try_from(1f64).unwrap();
        // Open the inspection connection *before* touching the limiter. With a
        // 1 ms cache TTL, establishing a brand-new connection between the cache
        // write (presumably made by `get_suppression_factor`) and the GET below
        // could let the key expire first and make the test flaky. Now only a
        // single round-trip separates the two.
        let mut conn = redis::Client::open(url.as_str())
            .unwrap()
            .get_multiplexed_async_connection()
            .await
            .unwrap();
        for _ in 0..11 {
            let _ = rl
                .redis()
                .suppressed()
                .inc(&k, &rate_limit, 1)
                .await
                .unwrap();
        }
        // Wait well past the 1 ms cache TTL before querying the factor.
        runtime::async_sleep(Duration::from_millis(10)).await;
        let _ = rl
            .redis()
            .suppressed()
            .get_suppression_factor(&k)
            .await
            .unwrap();
        let sf_raw: Option<String> = conn.get(redis_key(&prefix, &k, "sf")).await.unwrap();
        let sf: f64 = sf_raw
            .expect("suppression factor cache key should be set when in the suppression zone")
            .parse()
            .expect("sf should be a valid float");
        assert!(
            (0.0..=1.0).contains(&sf),
            "cached suppression factor must be in [0, 1], got {sf}"
        );
        assert!(sf > 0.0, "suppression factor should be > 0 past the soft limit");
    });
}
/// After eleven unit increments the suppression-factor cache key must exist
/// with a positive TTL no larger than the configured cache duration.
#[test]
fn redis_state_suppressed_sf_cache_key_has_ttl() {
    let url = redis_url();
    runtime::block_on(async {
        let cache_ms = 500_u64;
        let (limiter, prefix) = build_limiter(&url, 10, 100, 2.0, cache_ms).await;
        let user_key = key("k");
        let rate_limit = RateLimit::try_from(1f64).unwrap();

        for _call in 1..=11 {
            let _ = limiter
                .redis()
                .suppressed()
                .inc(&user_key, &rate_limit, 1)
                .await
                .unwrap();
        }

        let mut inspector = redis::Client::open(url.as_str())
            .unwrap()
            .get_multiplexed_async_connection()
            .await
            .unwrap();
        let sf_key = redis_key(&prefix, &user_key, "sf");
        let pttl: i64 = inspector.pttl(sf_key).await.unwrap();

        // PTTL reports -2 for a missing key and -1 for a key without an expiry,
        // so a positive value shows the key exists and was written with one.
        assert!(
            pttl > 0,
            "suppression factor cache key should have a positive TTL (PX set), got {pttl}"
        );
        assert!(
            pttl <= cache_ms as i64,
            "TTL should be <= cache_ms={cache_ms}, got {pttl}"
        );
    });
}
/// The window-limit key written by the first `inc` must hold the hard limit.
/// For these settings that is 5 × 4 × 3.0 = 60 (window × rate × factor).
#[test]
fn redis_state_suppressed_window_limit_key_equals_hard_limit() {
    let url = redis_url();
    runtime::block_on(async {
        let window_size_seconds = 5_u64;
        let rate_limit_value = 4f64;
        let hard_limit_factor = 3.0_f64;
        // window_size_seconds * rate_limit_value * hard_limit_factor
        let expected_hard_limit = 60_u64;

        let (limiter, prefix) =
            build_limiter(&url, window_size_seconds, 1000, hard_limit_factor, 100).await;
        let user_key = key("k");
        let rate_limit = RateLimit::try_from(rate_limit_value).unwrap();
        limiter
            .redis()
            .suppressed()
            .inc(&user_key, &rate_limit, 1)
            .await
            .unwrap();

        let mut inspector = redis::Client::open(url.as_str())
            .unwrap()
            .get_multiplexed_async_connection()
            .await
            .unwrap();
        let stored_limit: u64 = inspector
            .get(redis_key(&prefix, &user_key, "w"))
            .await
            .unwrap();
        assert_eq!(
            stored_limit, expected_hard_limit,
            "stored window limit should equal hard_limit"
        );
    });
}
/// The bucket hashes must stay consistent with the aggregate counters after a
/// large increment followed by several unit increments. The bucket hash must
/// sum to the total count, and the declined hash must sum to the declined count.
#[test]
fn redis_state_suppressed_hash_sums_match_counter_keys() {
    let url = redis_url();
    runtime::block_on(async {
        let (rl, prefix) = build_limiter(&url, 10, 100, 10.0, 1).await;
        let k = key("k");
        let rate_limit = RateLimit::try_from(1f64).unwrap();
        let _ = rl
            .redis()
            .suppressed()
            .inc(&k, &rate_limit, 100)
            .await
            .unwrap();
        runtime::async_sleep(Duration::from_millis(10)).await;
        for _ in 0..5 {
            let _ = rl
                .redis()
                .suppressed()
                .inc(&k, &rate_limit, 1)
                .await
                .unwrap();
        }
        let mut conn = redis::Client::open(url.as_str())
            .unwrap()
            .get_multiplexed_async_connection()
            .await
            .unwrap();
        let total: u64 = conn.get(redis_key(&prefix, &k, "t")).await.unwrap();
        // A missing declined counter (nil reply) means nothing was declined and
        // counts as 0. Any other error (connection, type mismatch) must fail the
        // test rather than be coerced to 0, which could mask a hash/counter
        // mismatch.
        let declined: Option<u64> = conn.get(redis_key(&prefix, &k, "d")).await.unwrap();
        let declined = declined.unwrap_or(0);
        let hash: HashMap<String, u64> =
            conn.hgetall(redis_key(&prefix, &k, "h")).await.unwrap();
        let hash_d: HashMap<String, u64> =
            conn.hgetall(redis_key(&prefix, &k, "hd")).await.unwrap();
        let hash_sum: u64 = hash.values().sum();
        let hash_d_sum: u64 = hash_d.values().sum();
        assert_eq!(
            hash_sum, total,
            "hash sum ({hash_sum}) must equal total count ({total})"
        );
        assert_eq!(
            hash_d_sum, declined,
            "declined hash sum ({hash_d_sum}) must equal declined count ({declined})"
        );
    });
}
/// Once the one-second window has fully elapsed, old buckets must be evicted.
/// Only the fresh increment may remain in the total counter, the bucket hash
/// and the active-keys sorted set.
#[test]
fn redis_state_suppressed_evicts_expired_buckets_from_both_hashes() {
    let url = redis_url();
    runtime::block_on(async {
        let window_size_seconds = 1_u64;
        let (limiter, prefix) =
            build_limiter(&url, window_size_seconds, 1000, 10.0, 1).await;
        let user_key = key("k");
        let rate_limit = RateLimit::try_from(1f64).unwrap();

        // Seed the window with traffic that should later be evicted.
        let _ = limiter
            .redis()
            .suppressed()
            .inc(&user_key, &rate_limit, 10)
            .await
            .unwrap();
        runtime::async_sleep(Duration::from_millis(5)).await;
        let _ = limiter
            .redis()
            .suppressed()
            .inc(&user_key, &rate_limit, 1)
            .await
            .unwrap();

        // Wait out the whole window plus a margin. This is a *blocking* sleep on
        // the runtime thread. NOTE(review): presumably chosen so real wall-clock
        // time passes for the Redis-side expiry; confirm before switching it to
        // `runtime::async_sleep`.
        let window_plus_margin_ms = window_size_seconds * 1000 + 50;
        thread::sleep(Duration::from_millis(window_plus_margin_ms));
        runtime::async_sleep(Duration::from_millis(10)).await;

        let d = limiter
            .redis()
            .suppressed()
            .inc(&user_key, &rate_limit, 1)
            .await
            .unwrap();
        assert!(matches!(d, RateLimitDecision::Allowed), "d: {d:?}");

        let mut inspector = redis::Client::open(url.as_str())
            .unwrap()
            .get_multiplexed_async_connection()
            .await
            .unwrap();
        let total: u64 = inspector
            .get(redis_key(&prefix, &user_key, "t"))
            .await
            .unwrap();
        let buckets: HashMap<String, u64> = inspector
            .hgetall(redis_key(&prefix, &user_key, "h"))
            .await
            .unwrap();
        let active_count: u64 = inspector
            .zcard(redis_key(&prefix, &user_key, "a"))
            .await
            .unwrap();

        assert_eq!(
            total, 1,
            "total count must be 1 after eviction (only the fresh increment remains)"
        );
        assert_eq!(
            buckets.len(),
            1,
            "hash must have exactly one bucket after eviction"
        );
        assert_eq!(
            active_count, 1,
            "active sorted set must have one entry after eviction"
        );
    });
}
/// Querying the suppression factor of a never-incremented key must yield 0.
/// It must also create neither the total counter nor the bucket hash.
#[test]
fn redis_state_suppressed_get_factor_on_unknown_key_writes_no_state() {
    let url = redis_url();
    runtime::block_on(async {
        let (limiter, prefix) = build_limiter(&url, 10, 1000, 2.0, 100).await;
        let fresh_key = key("brand_new");

        let sf = limiter
            .redis()
            .suppressed()
            .get_suppression_factor(&fresh_key)
            .await
            .unwrap();
        assert!(sf.abs() < 1e-12, "sf: {sf}");

        let mut inspector = redis::Client::open(url.as_str())
            .unwrap()
            .get_multiplexed_async_connection()
            .await
            .unwrap();
        let total: Option<u64> = inspector
            .get(redis_key(&prefix, &fresh_key, "t"))
            .await
            .unwrap();
        let hash_len: u64 = inspector
            .hlen(redis_key(&prefix, &fresh_key, "h"))
            .await
            .unwrap();

        assert!(
            total.is_none(),
            "total count key should not exist for an unknown key"
        );
        assert_eq!(hash_len, 0, "hash should be empty for an unknown key");
    });
}
/// Increments on two distinct user keys must land in separate total counters.
#[test]
fn redis_state_suppressed_per_key_state_is_independent() {
    let url = redis_url();
    runtime::block_on(async {
        let (limiter, prefix) = build_limiter(&url, 10, 1000, 2.0, 100).await;
        let (a, b) = (key("a"), key("b"));
        let rate_limit = RateLimit::try_from(1f64).unwrap();

        limiter.redis().suppressed().inc(&a, &rate_limit, 4).await.unwrap();
        limiter.redis().suppressed().inc(&b, &rate_limit, 9).await.unwrap();

        let mut inspector = redis::Client::open(url.as_str())
            .unwrap()
            .get_multiplexed_async_connection()
            .await
            .unwrap();
        let total_a: u64 = inspector.get(redis_key(&prefix, &a, "t")).await.unwrap();
        let total_b: u64 = inspector.get(redis_key(&prefix, &b, "t")).await.unwrap();

        assert_eq!(total_a, 4, "total for key a should be 4");
        assert_eq!(total_b, 9, "total for key b should be 9");
    });
}
/// The first `inc` for a key must register it in the prefix-wide
/// active-entities sorted set. The key is checked both before and after the call.
#[test]
fn redis_state_suppressed_active_entities_updated_on_inc() {
    let url = redis_url();
    runtime::block_on(async {
        let (limiter, prefix) = build_limiter(&url, 10, 1000, 2.0, 100).await;
        let entity = key("entity");
        let rate_limit = RateLimit::try_from(1f64).unwrap();
        let ae_key = active_entities_key(&prefix);

        let mut inspector = redis::Client::open(url.as_str())
            .unwrap()
            .get_multiplexed_async_connection()
            .await
            .unwrap();

        let score_before: Option<f64> = inspector.zscore(&ae_key, &**entity).await.unwrap();
        assert!(
            score_before.is_none(),
            "key should not be in active_entities before inc"
        );

        limiter.redis().suppressed().inc(&entity, &rate_limit, 1).await.unwrap();

        let score_after: Option<f64> = inspector.zscore(&ae_key, &**entity).await.unwrap();
        assert!(
            score_after.is_some(),
            "key should be in active_entities after inc"
        );
    });
}