use blvm_node::rpc::auth::{RpcAuthManager, RpcRateLimiter, UserId};
use std::net::SocketAddr;
use std::time::Duration;
use tokio::time::sleep;
/// A fresh limiter grants exactly its burst capacity (first constructor argument) and then
/// rejects further requests, with the remaining-token count tracking each consumption.
#[tokio::test]
async fn test_rate_limiter_basic_functionality() {
    let mut limiter = RpcRateLimiter::new(10, 5);
    for i in 0..10 {
        let expected_left = 10 - (i + 1);
        assert!(limiter.check_and_consume(), "Request {i} should be allowed");
        assert_eq!(limiter.tokens_remaining(), expected_left);
    }

    // Burst exhausted: the next request is refused and the bucket stays empty.
    assert!(!limiter.check_and_consume());
    assert_eq!(limiter.tokens_remaining(), 0);
}
/// After the bucket is fully drained, waiting a little over one second should make
/// tokens available again.
#[tokio::test]
async fn test_rate_limiter_refill_over_time() {
    let mut limiter = RpcRateLimiter::new(10, 10);
    (0..10).for_each(|_| {
        limiter.check_and_consume();
    });
    assert_eq!(limiter.tokens_remaining(), 0);

    sleep(Duration::from_millis(1100)).await;

    let refilled = limiter.tokens_remaining() > 0;
    if refilled {
        assert!(limiter.check_and_consume());
    } else {
        // Refill was not observed yet; allow some extra slack before checking again.
        sleep(Duration::from_millis(200)).await;
        assert!(limiter.tokens_remaining() > 0 || limiter.check_and_consume());
    }
}
/// Even with a fast refill rate, the bucket must never hold more tokens than its burst size.
#[tokio::test]
async fn test_rate_limiter_burst_limit_cap() {
    let mut limiter = RpcRateLimiter::new(5, 100);
    (0..5).for_each(|_| {
        limiter.check_and_consume();
    });

    // Assuming the second constructor argument is a per-second refill rate, 100 ms would
    // add roughly 10 tokens if the bucket were uncapped.
    sleep(Duration::from_millis(100)).await;

    let tokens = limiter.tokens_remaining();
    let within_cap = tokens <= 5;
    assert!(within_cap, "Tokens should not exceed burst limit of 5, got {tokens}");
}
/// Per-user limits are independent: exhausting one authenticated user's bucket must not
/// affect another user.
#[tokio::test]
async fn test_auth_manager_rate_limit_per_user() {
    let manager = RpcAuthManager::new(false);
    manager.add_token("user1".to_string()).await.unwrap();
    manager.add_token("user2".to_string()).await.unwrap();
    let addr: SocketAddr = "127.0.0.1:8080".parse().unwrap();

    // Authenticate both bearer tokens to obtain their user identities.
    let mut first_headers = hyper::HeaderMap::new();
    first_headers.insert("authorization", "Bearer user1".parse().unwrap());
    let user1 = manager
        .authenticate_request(&first_headers, addr)
        .await
        .user_id
        .unwrap();

    let mut second_headers = hyper::HeaderMap::new();
    second_headers.insert("authorization", "Bearer user2".parse().unwrap());
    let user2 = manager
        .authenticate_request(&second_headers, addr)
        .await
        .user_id
        .unwrap();

    // user1 gets a burst of 5, user2 a burst of 10.
    manager.set_user_rate_limit(&user1, 5, 2).await;
    manager.set_user_rate_limit(&user2, 10, 5).await;

    for i in 0..5 {
        assert!(
            manager.check_rate_limit(&user1).await,
            "User1 request {i} should be allowed"
        );
    }
    assert!(
        !manager.check_rate_limit(&user1).await,
        "User1 should be rate limited after burst"
    );
    assert!(
        manager.check_rate_limit(&user2).await,
        "User2 should still have tokens"
    );
}
/// Method-level limits are tracked per method name, so draining one method's bucket
/// leaves other methods untouched.
#[tokio::test]
async fn test_auth_manager_rate_limit_per_method() {
    let manager = RpcAuthManager::new(false);
    manager.set_method_rate_limit("expensive_method", 2, 1).await;
    manager.set_method_rate_limit("cheap_method", 10, 5).await;

    // "expensive_method" has a burst of 2: two calls pass, the third is refused.
    for _ in 0..2 {
        assert!(manager.check_method_rate_limit("expensive_method").await);
    }
    assert!(!manager.check_method_rate_limit("expensive_method").await);

    // The exhausted bucket above must not leak into a different method.
    assert!(manager.check_method_rate_limit("cheap_method").await);
}
/// Unauthenticated clients are rate limited per IP address: each IP has its own bucket,
/// so exhausting one IP's allowance must not block a different IP.
#[tokio::test]
async fn test_auth_manager_ip_rate_limiting() {
    let manager = RpcAuthManager::new(false);
    let addr1: SocketAddr = "127.0.0.1:8080".parse().unwrap();
    let addr2: SocketAddr = "192.168.1.1:8080".parse().unwrap();
    let user1 = UserId::Ip(addr1);
    let user2 = UserId::Ip(addr2);

    // Both fresh IPs must be served on their first request.
    assert!(
        manager.check_rate_limit(&user1).await,
        "IP 1 should have tokens initially"
    );
    assert!(
        manager.check_rate_limit(&user2).await,
        "IP 2 should have tokens initially"
    );

    // Drain IP 1's bucket until it is actually refused. The attempt bound only prevents
    // an endless loop if limiting were broken; it is far above the expected IP burst size.
    let mut ip1_limited = false;
    for _ in 0..1_000 {
        if !manager.check_rate_limit(&user1).await {
            ip1_limited = true;
            break;
        }
    }
    assert!(ip1_limited, "IP 1 should eventually be rate limited");

    // IP 2 has spent only a single token, so it must still be allowed.
    assert!(
        manager.check_rate_limit(&user2).await,
        "IP 2 should be unaffected by IP 1 being rate limited"
    );
}
/// An IP that has actually been rate limited must be allowed again once its bucket has had
/// time to refill.
#[tokio::test]
async fn test_auth_manager_rate_limit_recovery() {
    let manager = RpcAuthManager::new(false);
    let user = UserId::Ip("127.0.0.1:8080".parse().unwrap());

    // Drain until a request is genuinely refused. A recovery test is meaningless unless
    // the limit was hit first. The bound only guards against an endless loop.
    let mut exhausted = false;
    for _ in 0..1_000 {
        if !manager.check_rate_limit(&user).await {
            exhausted = true;
            break;
        }
    }
    assert!(
        exhausted,
        "IP should be rate limited after exhausting its burst"
    );

    // Wait a little over one refill period.
    sleep(Duration::from_millis(1300)).await;

    let allowed = manager.check_rate_limit(&user).await;
    assert!(allowed, "Rate limit should recover after refill period");
}
/// Concurrent requests from a single IP must never be granted more tokens than the IP burst
/// limit allows.
#[tokio::test]
async fn test_auth_manager_concurrent_rate_limiting() {
    use std::sync::Arc;

    // Assumed default burst for IP-keyed clients. NOTE(review): confirm against
    // RpcAuthManager's defaults. Issue more requests than this so the cap is exercised.
    const IP_BURST_LIMIT: usize = 50;
    const REQUESTS: usize = IP_BURST_LIMIT * 2;

    let manager = Arc::new(RpcAuthManager::new(false));
    let user = UserId::Ip("127.0.0.1:8080".parse().unwrap());

    let handles: Vec<_> = (0..REQUESTS)
        .map(|_| {
            let manager = Arc::clone(&manager);
            let user = user.clone();
            tokio::spawn(async move { manager.check_rate_limit(&user).await })
        })
        .collect();

    let mut allowed_count = 0;
    for handle in handles {
        if handle.await.unwrap() {
            allowed_count += 1;
        }
    }

    assert!(
        allowed_count > 0,
        "At least one concurrent request should be allowed"
    );
    assert!(
        allowed_count <= IP_BURST_LIMIT,
        "Should not exceed IP burst limit of {IP_BURST_LIMIT}, allowed {allowed_count}"
    );
}
/// With a refill rate of zero, spent tokens are never replenished, no matter how long we wait.
#[tokio::test]
async fn test_rate_limiter_zero_rate() {
    let mut limiter = RpcRateLimiter::new(10, 0);
    (0..10).for_each(|_| assert!(limiter.check_and_consume()));

    sleep(Duration::from_millis(1100)).await;

    let remaining = limiter.tokens_remaining();
    assert_eq!(remaining, 0);
    assert!(!limiter.check_and_consume());
}
/// With a high refill rate, a drained bucket recovers quickly yet still respects the
/// burst cap of 10.
#[tokio::test]
#[ignore]
async fn test_rate_limiter_high_rate() {
    let mut limiter = RpcRateLimiter::new(10, 1000);
    (0..10).for_each(|_| {
        limiter.check_and_consume();
    });

    sleep(Duration::from_millis(100)).await;

    let had_tokens = limiter.check_and_consume();
    let tokens = limiter.tokens_remaining();
    let within_cap = tokens <= 10;
    let refilled = had_tokens || tokens >= 8;
    assert!(within_cap, "Should not exceed burst limit of 10");
    assert!(
        refilled,
        "Should have refilled (had_tokens: {had_tokens}, remaining: {tokens})"
    );
}