use crate::config::PluginConfig;
use crate::dns::ResponseCode;
use crate::plugin::{Context, Plugin};
use crate::{RegisterPlugin, Result};
use async_trait::async_trait;
use dashmap::DashMap;
use std::any::Any;
use std::net::IpAddr;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tracing::warn;
/// Per-client bookkeeping for the rate-limit window currently in progress.
#[derive(Debug)]
struct RateLimitEntry {
    /// Number of queries counted for the client since `window_start`.
    count: u32,
    /// Moment at which the current counting window began.
    window_start: Instant,
}
/// Plugin that counts queries per client IP address and refuses clients that
/// send more than `max_queries` queries within one window.
#[derive(Debug, RegisterPlugin)]
pub struct RateLimitPlugin {
    /// Highest query count a client may reach inside one window before
    /// further queries are reported as over the limit.
    max_queries: u32,
    /// Length of a counting window, in seconds.
    window_secs: u64,
    /// Per-client counters, keyed by the client's IP address.
    limits: Arc<DashMap<IpAddr, RateLimitEntry>>,
}
impl RateLimitPlugin {
    /// Creates a limiter that admits up to `max_queries` queries per client
    /// within each window of `window_secs` seconds.
    ///
    /// The limiter starts with no tracked clients.
    pub fn new(max_queries: u32, window_secs: u64) -> Self {
        Self {
            max_queries,
            window_secs,
            limits: Arc::new(DashMap::new()),
        }
    }

    /// Records one query from `client_ip` and returns `true` when that query
    /// pushes the client over `max_queries` for its current window.
    ///
    /// The first query of a window (a new client, or a client whose previous
    /// window has expired) is always admitted and starts a fresh window with
    /// a count of 1. With `window_secs == 0` every query starts a new window,
    /// so nothing is ever limited.
    fn should_limit(&self, client_ip: IpAddr) -> bool {
        let now = Instant::now();
        let window_duration = Duration::from_secs(self.window_secs);
        let mut should_limit = false;

        self.limits
            .entry(client_ip)
            .and_modify(|entry| {
                // `now` is sampled before the map entry is obtained, so a
                // `window_start` stored by a concurrent caller can be later
                // than `now`. Saturate to zero instead of relying on
                // `duration_since`, whose panic behaviour is not guaranteed.
                let elapsed = now.saturating_duration_since(entry.window_start);
                if elapsed >= window_duration {
                    // Previous window is over: this query opens a new one.
                    entry.count = 1;
                    entry.window_start = now;
                } else {
                    // Saturate so a flood of queries cannot overflow the
                    // counter (panic in debug, wrap to 0 in release).
                    entry.count = entry.count.saturating_add(1);
                    should_limit = entry.count > self.max_queries;
                }
            })
            .or_insert(RateLimitEntry {
                count: 1,
                window_start: now,
            });

        should_limit
    }

    /// Drops every tracked client whose window started at least two window
    /// lengths ago.
    ///
    /// Nothing in this block calls it; callers are expected to invoke it
    /// themselves to keep the map from growing without bound.
    pub fn cleanup(&self) {
        let now = Instant::now();
        // `Duration * u32` panics on overflow; saturate so very large
        // `window_secs` values cannot bring cleanup down.
        let max_age = Duration::from_secs(self.window_secs).saturating_mul(2);
        self.limits
            .retain(|_, entry| now.saturating_duration_since(entry.window_start) < max_age);
    }
}
#[async_trait]
impl Plugin for RateLimitPlugin {
async fn execute(&self, ctx: &mut Context) -> Result<()> {
let client_ip: IpAddr = match ctx.get_metadata::<IpAddr>("client_ip") {
Some(ip) => *ip,
None => {
return Ok(());
}
};
if self.should_limit(client_ip) {
warn!("Rate limit exceeded for IP: {}", client_ip);
let mut response = crate::dns::Message::new();
response.set_id(ctx.request().id());
response.set_response(true);
response.set_response_code(ResponseCode::Refused);
ctx.set_response(Some(response));
}
Ok(())
}
fn name(&self) -> &str {
"rate_limit"
}
fn priority(&self) -> i32 {
1000
}
fn as_any(&self) -> &dyn Any {
self
}
fn init(config: &PluginConfig) -> Result<Arc<dyn Plugin>> {
use serde_yaml::Value;
use std::sync::Arc;
let mut max_queries: u32 = 100;
let mut window_secs: u64 = 60;
let args = config.effective_args();
if let Some(v) = args.get("max_queries") {
match v {
Value::Number(n) => {
if let Some(u) = n.as_u64() {
max_queries = u as u32;
} else if let Some(i) = n.as_i64() {
max_queries = i as u32;
} else {
return Err(crate::Error::Config(
"Invalid 'max_queries' value".to_string(),
));
}
}
Value::String(s) => {
max_queries = s.parse::<u32>().map_err(|_| {
crate::Error::Config("Invalid 'max_queries' string value".to_string())
})?;
}
_ => {
return Err(crate::Error::Config(
"Invalid 'max_queries' type".to_string(),
));
}
}
}
if let Some(v) = args.get("window_secs") {
match v {
Value::Number(n) => {
if let Some(u) = n.as_u64() {
window_secs = u;
} else if let Some(i) = n.as_i64() {
window_secs = i as u64;
} else {
return Err(crate::Error::Config(
"Invalid 'window_secs' value".to_string(),
));
}
}
Value::String(s) => {
window_secs = s.parse::<u64>().map_err(|_| {
crate::Error::Config("Invalid 'window_secs' string value".to_string())
})?;
}
_ => {
return Err(crate::Error::Config(
"Invalid 'window_secs' type".to_string(),
));
}
}
}
Ok(Arc::new(RateLimitPlugin::new(max_queries, window_secs)))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dns::Message;

    /// Parses an IP address literal used by the tests below.
    fn addr(text: &str) -> IpAddr {
        text.parse().unwrap()
    }

    /// The constructor stores both settings unchanged.
    #[test]
    fn test_rate_limit_creation() {
        let plugin = RateLimitPlugin::new(10, 60);
        assert_eq!(plugin.max_queries, 10);
        assert_eq!(plugin.window_secs, 60);
    }

    /// Exactly `max_queries` queries in a window are all admitted.
    #[test]
    fn test_rate_limit_allows_within_limit() {
        let plugin = RateLimitPlugin::new(5, 60);
        let client = addr("192.168.1.1");
        (0..5).for_each(|_| assert!(!plugin.should_limit(client)));
    }

    /// The query after the last admitted one is reported as over the limit.
    #[test]
    fn test_rate_limit_blocks_over_limit() {
        let plugin = RateLimitPlugin::new(3, 60);
        let client = addr("192.168.1.2");
        (0..3).for_each(|_| assert!(!plugin.should_limit(client)));
        assert!(plugin.should_limit(client));
    }

    /// Each client address has its own independent counter.
    #[test]
    fn test_rate_limit_different_ips() {
        let plugin = RateLimitPlugin::new(2, 60);
        let clients = [addr("192.168.1.1"), addr("192.168.1.2")];

        // Two interleaved rounds stay within the limit for both clients.
        for _ in 0..2 {
            for client in clients.iter().copied() {
                assert!(!plugin.should_limit(client));
            }
        }
        // The third query from each client is over the limit.
        for client in clients.iter().copied() {
            assert!(plugin.should_limit(client));
        }
    }

    /// `cleanup` keeps fresh entries and drops ones older than two windows.
    #[test]
    fn test_cleanup() {
        let plugin = RateLimitPlugin::new(10, 1);
        let client = addr("192.168.1.3");

        plugin.should_limit(client);
        assert_eq!(plugin.limits.len(), 1);

        // The entry was just created, so it must survive a cleanup pass.
        plugin.cleanup();
        assert_eq!(plugin.limits.len(), 1);

        // Wait past twice the one-second window, then the entry must go.
        std::thread::sleep(Duration::from_secs(3));
        plugin.cleanup();
        assert_eq!(plugin.limits.len(), 0);
    }

    /// `execute` leaves the context alone until the limit is exceeded, then
    /// installs a `Refused` response.
    #[tokio::test]
    async fn test_rate_limit_plugin_execute() {
        let plugin = RateLimitPlugin::new(2, 60);
        let mut ctx = Context::new(Message::new());
        ctx.set_metadata("client_ip", "192.168.1.1".parse::<IpAddr>().unwrap());

        // The first two queries are within the limit.
        for _ in 0..2 {
            plugin.execute(&mut ctx).await.unwrap();
            assert!(ctx.response().is_none());
        }

        // The third query is refused.
        plugin.execute(&mut ctx).await.unwrap();
        assert!(ctx.response().is_some());
        assert_eq!(
            ctx.response().unwrap().response_code(),
            ResponseCode::Refused
        );
    }
}
#[cfg(test)]
mod builder_init_tests {
    use super::*;
    use serde_yaml::Mapping;
    use serde_yaml::Value;

    /// `init` reads both settings from the plugin arguments and produces a
    /// `RateLimitPlugin` carrying them.
    #[test]
    fn test_init_from_config() {
        // Assemble the YAML argument mapping handed to the plugin.
        let key = |name: &str| Value::String(name.to_string());
        let mut args_map = Mapping::new();
        args_map.insert(key("max_queries"), Value::Number(100.into()));
        args_map.insert(key("window_secs"), Value::Number(60.into()));

        let cfg = crate::config::types::PluginConfig {
            tag: None,
            plugin_type: "rate_limit".to_string(),
            args: Value::Mapping(args_map),
            priority: 100,
            config: std::collections::HashMap::new(),
        };

        let plugin = RateLimitPlugin::init(&cfg).expect("init");
        assert_eq!(plugin.name(), "rate_limit");

        // Downcast to the concrete type to inspect the parsed settings.
        match plugin.as_ref().as_any().downcast_ref::<RateLimitPlugin>() {
            Some(rl) => {
                assert_eq!(rl.max_queries, 100);
                assert_eq!(rl.window_secs, 60);
            }
            None => panic!("unexpected plugin type"),
        }
    }
}