use std::{
net::IpAddr,
sync::{Arc, OnceLock},
time::SystemTime,
};
use hyper::{Request, Response, StatusCode};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use spacegate_kernel::{extension::OriginalIpAddr, helper_layers::function::Inner, BoxError, SgBody, SgRequestExt, SgResponseExt};
use crate::Plugin;
use spacegate_ext_redis::redis::Script;
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
pub struct RateLimitPluginConfig {
pub max_request_number: Option<u64>,
pub time_window_ms: Option<u64>,
pub report_ext: Value,
}
/// Runtime state of the `limit` plugin, built from `RateLimitPluginConfig` by `Plugin::create`.
///
/// Cloning is cheap: the two non-`Copy` fields are behind `Arc`.
#[derive(Debug, Clone)]
pub struct RateLimitPlugin {
    /// Request budget passed to the Redis script as its first argument
    /// (presumably "max requests per window" — semantics live in `limit/script.lua`; confirm there).
    pub max_request_number: u64,
    /// Window length in milliseconds, passed to the Redis script as its second argument.
    pub time_window_ms: u64,
    /// The config's `report_ext`, shared with every `RateLimitReport` built from this plugin.
    pub report_ext: Arc<Value>,
    /// Plugin instance id; embedded in the Redis counter keys so separate instances keep separate counters.
    pub id: Arc<str>,
}
impl RateLimitPlugin {
    /// Packages a `RateLimitReport` for a request this plugin turned away.
    ///
    /// `call` attaches the result to the extensions of its `429` response.
    ///
    /// * `rising_edge` – whether the Redis script answered with its "rising edge" code
    ///   rather than plain "exceeded" (exact meaning defined by `limit/script.lua`).
    /// * `original_ip_addr` – the client address the counter was keyed on.
    ///
    /// The report embeds a clone of `self`; that clone only bumps `Arc` counts.
    pub fn report(&self, rising_edge: bool, original_ip_addr: IpAddr) -> RateLimitReport {
        let plugin = self.clone();
        RateLimitReport {
            plugin,
            original_ip_addr,
            rising_edge,
        }
    }
}
/// Prefix of the Redis keys holding per-plugin, per-client rate-limit state
/// (followed by `{id}:{ip}` and `{id}:{ip}_ts`).
const CONF_LIMIT_KEY: &str = "sg:plugin:filter:limit:";
/// Request budget used when the config omits `max_request_number`.
const DEFAULT_MAX_REQUEST_NUMBER: u64 = 100;
/// Window length in milliseconds used when the config omits `time_window_ms`.
const DEFAULT_TIME_WINDOW_MS: u64 = 1_000;
/// Details of a rate-limit rejection. The plugin's `call` inserts one into the extensions
/// of the `429 Too Many Requests` response it returns.
#[derive(Debug, Clone)]
pub struct RateLimitReport {
    /// `true` when the Redis script returned its "rising edge" code instead of plain "exceeded"
    /// (presumably the first rejection after the limit is crossed — confirm against `limit/script.lua`).
    pub rising_edge: bool,
    /// Canonical client IP the limit was evaluated for.
    pub original_ip_addr: IpAddr,
    /// Clone of the plugin that rejected the request, including its `report_ext`.
    pub plugin: RateLimitPlugin,
}
/// Returns the shared rate-limit Lua script, constructed once on first use.
///
/// The script source is embedded at build time from `limit/script.lua` (relative to this file).
/// The keys and arguments it receives are supplied by `RateLimitPlugin`'s `Plugin::call`.
pub fn script() -> &'static Script {
    static SCRIPT: OnceLock<Script> = OnceLock::new();
    SCRIPT.get_or_init(|| Script::new(include_str!("./limit/script.lua")))
}
impl Plugin for RateLimitPlugin {
    const CODE: &'static str = "limit";

    fn meta() -> spacegate_model::PluginMetaData {
        crate::plugin_meta!(
            description: "Request rate limit plugin."
        )
    }

    /// Counts the request against a Redis-backed counter keyed by plugin id and canonical
    /// client IP. The request is then either forwarded to `inner` or answered with
    /// `429 Too Many Requests`.
    ///
    /// A rejected response carries a [`RateLimitReport`] in its extensions.
    ///
    /// # Errors
    /// - no Redis client is available for the request's gateway (`"missing gateway name"`);
    /// - the system clock reads earlier than the Unix epoch (or its millisecond value overflows `u64`);
    /// - invoking the Redis script fails.
    async fn call(&self, req: Request<SgBody>, inner: Inner) -> Result<Response<SgBody>, BoxError> {
        // Return codes of the Lua script that mean "reject"; any other value lets the request through.
        const EXCEEDED: i32 = 0;
        const RISING_EDGE: i32 = 1;

        let id = &self.id;
        let ip = req.extract::<OriginalIpAddr>().to_canonical();
        let mut conn = req.get_redis_client_by_gateway_name().ok_or("missing gateway name")?.get_conn().await;

        // A clock set before the epoch is an environment fault, not a program bug: fail this
        // request with an error instead of panicking inside the request path.
        let now_ms = SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map_err(|_| "invalid system time: before unix epoch")?
            .as_millis();
        let now_ms = u64::try_from(now_ms).map_err(|_| "invalid system time: milliseconds overflow u64")?;

        let result: i32 = script()
            .key(format!("{CONF_LIMIT_KEY}{id}:{ip}"))
            .key(format!("{CONF_LIMIT_KEY}{id}:{ip}_ts"))
            .arg(self.max_request_number)
            .arg(self.time_window_ms)
            .arg(now_ms)
            .invoke_async(&mut conn)
            .await?;

        if result == EXCEEDED || result == RISING_EDGE {
            let mut response = Response::<SgBody>::with_code_message(StatusCode::TOO_MANY_REQUESTS, "[SG.Filter.Limit] too many requests");
            response.extensions_mut().insert(self.report(result == RISING_EDGE, ip));
            return Ok(response);
        }
        Ok(inner.call(req).await)
    }

    /// Builds the plugin from its JSON spec. Omitted limits fall back to
    /// `DEFAULT_MAX_REQUEST_NUMBER` and `DEFAULT_TIME_WINDOW_MS`.
    ///
    /// # Errors
    /// Fails when `config.spec` does not deserialize into [`RateLimitPluginConfig`].
    fn create(config: crate::PluginConfig) -> Result<Self, BoxError> {
        let spec = serde_json::from_value::<RateLimitPluginConfig>(config.spec)?;
        let id = config.id.to_string();
        Ok(Self {
            max_request_number: spec.max_request_number.unwrap_or(DEFAULT_MAX_REQUEST_NUMBER),
            time_window_ms: spec.time_window_ms.unwrap_or(DEFAULT_TIME_WINDOW_MS),
            report_ext: Arc::new(spec.report_ext),
            id: Arc::from(id),
        })
    }

    #[cfg(feature = "schema")]
    fn schema_opt() -> Option<schemars::schema::RootSchema> {
        use crate::PluginSchemaExt;
        Some(Self::schema())
    }
}
// JSON-schema registration for this plugin, compiled only with the `schema` feature.
// `crate::schema!` is defined elsewhere in the crate; presumably it associates
// `RateLimitPlugin` with `RateLimitPluginConfig`'s schema — confirm at the macro definition.
#[cfg(feature = "schema")]
crate::schema! { RateLimitPlugin, RateLimitPluginConfig }