use crate::{
common::config::env_loader,
engine::interfaces::{
GenericMiddleware, Middleware, MiddlewareOutput, ParamDef, ParamType, Plugin, ResolvedInputs,
},
};
use anyhow::{Result, anyhow};
use async_trait::async_trait;
use dashmap::DashMap;
use fancy_log::{LogLevel, log};
use once_cell::sync::Lazy;
use serde_json::Value;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::{any::Any, borrow::Cow, sync::Arc, time::Duration};
/// Fixed per-entry byte cost added to `key.len()` when estimating how much
/// memory a limiter pool holds.
///
/// NOTE(review): presumably approximates `String` header + `u32` value + map
/// bookkeeping; how 92 was derived is not visible here — confirm.
const ENTRY_OVERHEAD: usize = 92;
static SEC_POOL: Lazy<Arc<DashMap<String, u32>>> = Lazy::new(|| {
let map = Arc::new(DashMap::new());
let map_clone = map.clone();
tokio::spawn(async move {
let mut interval = tokio::time::interval(Duration::from_secs(1));
loop {
interval.tick().await;
if !map_clone.is_empty() {
map_clone.clear();
SEC_POOL_USAGE.store(0, Ordering::Relaxed);
}
}
});
map
});
static SEC_POOL_USAGE: AtomicUsize = AtomicUsize::new(0);
static MIN_POOL: Lazy<Arc<DashMap<String, u32>>> = Lazy::new(|| {
let map = Arc::new(DashMap::new());
let map_clone = map.clone();
tokio::spawn(async move {
let mut interval = tokio::time::interval(Duration::from_secs(60));
loop {
interval.tick().await;
if !map_clone.is_empty() {
map_clone.clear();
MIN_POOL_USAGE.store(0, Ordering::Relaxed);
}
}
});
map
});
static MIN_POOL_USAGE: AtomicUsize = AtomicUsize::new(0);
/// Prunes about 10% of `map`'s keys when the estimated usage exceeds the budget.
///
/// The budget comes from `MAX_LIMITER_MEMORY`, in bytes. It defaults to 4_194_304 and
/// falls back to that value when the setting does not parse. `usage_counter` is the
/// pool's running byte estimate. Each entry actually removed subtracts
/// `ENTRY_OVERHEAD + key.len()` from it.
///
/// The keys evicted are simply the first ones the map iterator yields. They are not
/// chosen by age or by hit count.
fn ensure_space(map: &DashMap<String, u32>, usage_counter: &AtomicUsize) {
    let max_mem = env_loader::get_env("MAX_LIMITER_MEMORY", "4194304".to_owned())
        .parse::<usize>()
        .unwrap_or(4_194_304);
    let current_usage = usage_counter.load(Ordering::Relaxed);
    if current_usage <= max_mem {
        return;
    }
    log(
        LogLevel::Warn,
        &format!(
            "Rate limiter memory limit exceeded ({current_usage} > {max_mem} bytes). Pruning 10% of keys to self-preserve."
        ),
    );
    // Exact integer ceiling of len / 10, which avoids the float round-trip.
    let items_to_remove = map.len().div_ceil(10);
    // Collect the victims first so no DashMap iterator is alive while we remove entries.
    // DashMap documents that iteration can deadlock against concurrent mutable access.
    let keys_to_remove: Vec<String> = map
        .iter()
        .take(items_to_remove)
        .map(|kv| kv.key().clone())
        .collect();
    for key in keys_to_remove {
        if map.remove(&key).is_some() {
            let entry_size = ENTRY_OVERHEAD + key.len();
            // The counter is only an estimate and can drift below the real total, for
            // example when a reset races an insert. A plain `fetch_sub` would then wrap to a
            // huge value and trigger pruning on every insert, so saturate at zero instead.
            let _ = usage_counter.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |used| {
                Some(used.saturating_sub(entry_size))
            });
        }
    }
}
/// Returns `true` if `key` is at most `RATELIMIT_KEY_MAX_LEN` bytes long.
///
/// The limit defaults to 256 and falls back to 256 when the configured value does
/// not parse as a `usize`.
fn check_key_length(key: &str) -> bool {
    let allowed_bytes = env_loader::get_env("RATELIMIT_KEY_MAX_LEN", "256".to_owned())
        .parse::<usize>()
        .unwrap_or(256);
    allowed_bytes >= key.len()
}
/// Keyword rate limiter whose counters live in `SEC_POOL`, which is reset every second.
pub struct KeywordRateLimitSecPlugin;

impl Plugin for KeywordRateLimitSecPlugin {
    fn name(&self) -> &'static str {
        "internal.common.ratelimit.sec"
    }

    /// Declares two required inputs: `key` (string) and `limit` (integer).
    fn params(&self) -> Vec<ParamDef> {
        let key_param = ParamDef {
            name: "key".into(),
            required: true,
            param_type: ParamType::String,
        };
        let limit_param = ParamDef {
            name: "limit".into(),
            required: true,
            param_type: ParamType::Integer,
        };
        vec![key_param, limit_param]
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn as_middleware(&self) -> Option<&dyn Middleware> {
        Some(self)
    }

    fn as_generic_middleware(&self) -> Option<&dyn GenericMiddleware> {
        Some(self)
    }
}
#[async_trait]
impl GenericMiddleware for KeywordRateLimitSecPlugin {
    fn output(&self) -> Vec<Cow<'static, str>> {
        vec!["true".into(), "false".into()]
    }

    /// Records one hit for `key` in `SEC_POOL` and picks a branch.
    ///
    /// Returns branch `"true"` while the key's count in the current one-second window
    /// is `<= limit`, and `"false"` otherwise. Keys longer than the configured maximum
    /// length return `"false"` immediately and are not counted.
    ///
    /// # Errors
    /// Fails if `key` is not a string or `limit` is not a non-negative integer.
    async fn execute(&self, inputs: ResolvedInputs) -> Result<MiddlewareOutput> {
        let key = inputs
            .get("key")
            .and_then(Value::as_str)
            .ok_or_else(|| anyhow!("Input 'key' missing"))?;
        // Saturate rather than truncate: `as u32` would turn e.g. 2^32 into 0 and deny
        // every request.
        let limit = inputs
            .get("limit")
            .and_then(Value::as_u64)
            .map(|raw| u32::try_from(raw).unwrap_or(u32::MAX))
            .ok_or_else(|| anyhow!("Input 'limit' missing"))?;
        if !check_key_length(key) {
            return Ok(MiddlewareOutput {
                branch: "false".into(),
                store: None,
            });
        }
        let pool = &*SEC_POOL;
        // Fast path: the key already exists. The guard is dropped inside the closure,
        // so no lock is held afterwards.
        let existing = pool.get_mut(key).map(|mut count| {
            let next = count.saturating_add(1);
            *count = next;
            next
        });
        let current_count = match existing {
            Some(count) => count,
            None => {
                ensure_space(pool, &SEC_POOL_USAGE);
                // Go through `entry` rather than `insert`: if another request created the
                // key since the lookup above, we increment its counter instead of
                // resetting it to 1. Usage is only charged when we really insert.
                let mut count = pool.entry(key.to_owned()).or_insert_with(|| {
                    SEC_POOL_USAGE.fetch_add(ENTRY_OVERHEAD + key.len(), Ordering::Relaxed);
                    0
                });
                let next = count.saturating_add(1);
                *count = next;
                next
            }
        };
        let branch = if current_count <= limit {
            "true"
        } else {
            "false"
        };
        Ok(MiddlewareOutput {
            branch: branch.into(),
            store: None,
        })
    }
}
#[async_trait]
impl Middleware for KeywordRateLimitSecPlugin {
fn output(&self) -> Vec<Cow<'static, str>> {
<Self as GenericMiddleware>::output(self)
}
async fn execute(&self, inputs: ResolvedInputs) -> Result<MiddlewareOutput> {
<Self as GenericMiddleware>::execute(self, inputs).await
}
}
/// Keyword rate limiter whose counters live in `MIN_POOL`, which is reset every 60 seconds.
pub struct KeywordRateLimitMinPlugin;

impl Plugin for KeywordRateLimitMinPlugin {
    fn name(&self) -> &'static str {
        "internal.common.ratelimit.min"
    }

    /// Declares two required inputs: `key` (string) and `limit` (integer).
    fn params(&self) -> Vec<ParamDef> {
        let key_param = ParamDef {
            name: "key".into(),
            required: true,
            param_type: ParamType::String,
        };
        let limit_param = ParamDef {
            name: "limit".into(),
            required: true,
            param_type: ParamType::Integer,
        };
        vec![key_param, limit_param]
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn as_middleware(&self) -> Option<&dyn Middleware> {
        Some(self)
    }

    fn as_generic_middleware(&self) -> Option<&dyn GenericMiddleware> {
        Some(self)
    }
}
#[async_trait]
impl GenericMiddleware for KeywordRateLimitMinPlugin {
    fn output(&self) -> Vec<Cow<'static, str>> {
        vec!["true".into(), "false".into()]
    }

    /// Records one hit for `key` in `MIN_POOL` and picks a branch.
    ///
    /// Returns branch `"true"` while the key's count in the current 60-second window
    /// is `<= limit`, and `"false"` otherwise. Keys longer than the configured maximum
    /// length return `"false"` immediately and are not counted.
    ///
    /// # Errors
    /// Fails if `key` is not a string or `limit` is not a non-negative integer.
    async fn execute(&self, inputs: ResolvedInputs) -> Result<MiddlewareOutput> {
        let key = inputs
            .get("key")
            .and_then(Value::as_str)
            .ok_or_else(|| anyhow!("Input 'key' missing"))?;
        // Saturate rather than truncate: `as u32` would turn e.g. 2^32 into 0 and deny
        // every request.
        let limit = inputs
            .get("limit")
            .and_then(Value::as_u64)
            .map(|raw| u32::try_from(raw).unwrap_or(u32::MAX))
            .ok_or_else(|| anyhow!("Input 'limit' missing"))?;
        if !check_key_length(key) {
            return Ok(MiddlewareOutput {
                branch: "false".into(),
                store: None,
            });
        }
        let pool = &*MIN_POOL;
        // Fast path: the key already exists. The guard is dropped inside the closure,
        // so no lock is held afterwards.
        let existing = pool.get_mut(key).map(|mut count| {
            let next = count.saturating_add(1);
            *count = next;
            next
        });
        let current_count = match existing {
            Some(count) => count,
            None => {
                ensure_space(pool, &MIN_POOL_USAGE);
                // Go through `entry` rather than `insert`: if another request created the
                // key since the lookup above, we increment its counter instead of
                // resetting it to 1. Usage is only charged when we really insert.
                let mut count = pool.entry(key.to_owned()).or_insert_with(|| {
                    MIN_POOL_USAGE.fetch_add(ENTRY_OVERHEAD + key.len(), Ordering::Relaxed);
                    0
                });
                let next = count.saturating_add(1);
                *count = next;
                next
            }
        };
        let branch = if current_count <= limit {
            "true"
        } else {
            "false"
        };
        Ok(MiddlewareOutput {
            branch: branch.into(),
            store: None,
        })
    }
}
#[async_trait]
impl Middleware for KeywordRateLimitMinPlugin {
fn output(&self) -> Vec<Cow<'static, str>> {
<Self as GenericMiddleware>::output(self)
}
async fn execute(&self, inputs: ResolvedInputs) -> Result<MiddlewareOutput> {
<Self as GenericMiddleware>::execute(self, inputs).await
}
}