use serde::{Deserialize, Serialize};
use worker::Env;
/// Client for a rate limiter whose state lives in a Durable Object namespace.
///
/// Every request is routed to the Durable Object instance named after the client id,
/// carrying the limit and window from `config`.
#[derive(Clone)]
pub struct DurableObjectRateLimiter {
    /// Name of the Durable Object binding, resolved through `Env::durable_object`.
    namespace: String,
    /// Worker environment; while `None`, every request fails with
    /// `RateLimitError::NoEnvironment`.
    env: Option<Env>,
    /// Limit and window sent with each check/peek request.
    config: RateLimitConfig,
}
/// Quota for the limiter: `limit` requests per window of `window_ms` milliseconds.
///
/// Derives `PartialEq`/`Eq`/`Hash` so configurations can be compared, deduplicated or
/// used as map keys like any other small `Copy` value type.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct RateLimitConfig {
    /// Number of requests allowed per window.
    pub limit: u64,
    /// Window length in milliseconds.
    pub window_ms: u64,
}
impl Default for RateLimitConfig {
    /// One hundred requests per one-minute window.
    fn default() -> Self {
        let window_ms = 60 * 1_000;
        let limit = 100;
        Self { limit, window_ms }
    }
}
impl RateLimitConfig {
    const SECOND_MS: u64 = 1_000;
    const MINUTE_MS: u64 = 60 * Self::SECOND_MS;
    const HOUR_MS: u64 = 60 * Self::MINUTE_MS;

    /// Builds a quota of `limit` requests per window of `window_ms` milliseconds.
    pub fn new(limit: u64, window_ms: u64) -> Self {
        Self { window_ms, limit }
    }

    /// `limit` requests per one-second window.
    pub fn per_second(limit: u64) -> Self {
        Self::new(limit, Self::SECOND_MS)
    }

    /// `limit` requests per one-minute window.
    pub fn per_minute(limit: u64) -> Self {
        Self::new(limit, Self::MINUTE_MS)
    }

    /// `limit` requests per one-hour window.
    pub fn per_hour(limit: u64) -> Self {
        Self::new(limit, Self::HOUR_MS)
    }
}
/// Outcome of a check or peek, decoded from the Durable Object's JSON response.
///
/// Derives `PartialEq`/`Eq` so results can be compared directly (e.g. in tests).
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct RateLimitResult {
    /// Whether the request falls within the limit.
    pub allowed: bool,
    /// Requests remaining in the current window.
    pub remaining: u64,
    /// Limit the decision was made against.
    pub limit: u64,
    /// Window reset, in milliseconds. NOTE(review): whether this is an absolute
    /// timestamp or a remaining duration is decided by the Durable Object — confirm there.
    pub reset_ms: u64,
    /// Suggested wait before retrying, in milliseconds; presumably `Some` only when
    /// `allowed` is false — confirm against the Durable Object.
    pub retry_after_ms: Option<u64>,
}
impl DurableObjectRateLimiter {
pub fn new(namespace: impl Into<String>) -> Self {
Self {
namespace: namespace.into(),
env: None,
config: RateLimitConfig::default(),
}
}
pub fn from_env(env: &Env, binding: &str) -> worker::Result<Self> {
let _ = env.durable_object(binding)?;
Ok(Self {
namespace: binding.to_string(),
env: Some(env.clone()),
config: RateLimitConfig::default(),
})
}
pub fn with_env(mut self, env: Env) -> Self {
self.env = Some(env);
self
}
pub fn with_config(mut self, config: RateLimitConfig) -> Self {
self.config = config;
self
}
pub fn with_limit(mut self, limit: u64) -> Self {
self.config.limit = limit;
self
}
pub fn with_window_ms(mut self, window_ms: u64) -> Self {
self.config.window_ms = window_ms;
self
}
pub async fn check(&self, client_id: &str) -> Result<RateLimitResult, RateLimitError> {
#[derive(Serialize)]
struct CheckRequest<'a> {
limit: u64,
window_ms: u64,
record: bool,
client_id: &'a str,
}
let request = CheckRequest {
limit: self.config.limit,
window_ms: self.config.window_ms,
record: true,
client_id,
};
self.do_request(client_id, "/rate-limit/check", Some(&request))
.await
}
pub async fn peek(&self, client_id: &str) -> Result<RateLimitResult, RateLimitError> {
#[derive(Serialize)]
struct CheckRequest<'a> {
limit: u64,
window_ms: u64,
record: bool,
client_id: &'a str,
}
let request = CheckRequest {
limit: self.config.limit,
window_ms: self.config.window_ms,
record: false,
client_id,
};
self.do_request(client_id, "/rate-limit/check", Some(&request))
.await
}
pub async fn reset(&self, client_id: &str) -> Result<(), RateLimitError> {
self.do_request::<()>(client_id, "/rate-limit/reset", None::<&()>)
.await
}
async fn do_request<T: for<'de> Deserialize<'de>>(
&self,
client_id: &str,
path: &str,
body: Option<&impl Serialize>,
) -> Result<T, RateLimitError> {
let env = self.env.as_ref().ok_or(RateLimitError::NoEnvironment)?;
let ns = env
.durable_object(&self.namespace)
.map_err(RateLimitError::Worker)?;
let id = ns.id_from_name(client_id).map_err(RateLimitError::Worker)?;
let stub = id.get_stub().map_err(RateLimitError::Worker)?;
let mut init = worker::RequestInit::new();
init.with_method(worker::Method::Post);
if let Some(body) = body {
let json = serde_json::to_string(body).map_err(RateLimitError::Serialization)?;
init.with_body(Some(json.into()));
}
let url = format!("https://do-internal{path}");
let request =
worker::Request::new_with_init(&url, &init).map_err(RateLimitError::Worker)?;
let mut response = stub
.fetch_with_request(request)
.await
.map_err(RateLimitError::Worker)?;
let text = response.text().await.map_err(RateLimitError::Worker)?;
serde_json::from_str(&text).map_err(RateLimitError::Deserialization)
}
}
/// Failure modes of a `DurableObjectRateLimiter` request.
#[derive(Debug)]
pub enum RateLimitError {
    /// The limiter has no `Env` attached (created with `new` and never given `with_env`).
    NoEnvironment,
    /// The worker runtime failed: binding lookup, object id or stub creation, request
    /// construction, fetch, or reading the response body.
    Worker(worker::Error),
    /// The request body could not be encoded as JSON.
    Serialization(serde_json::Error),
    /// The response body could not be decoded as the expected JSON type.
    Deserialization(serde_json::Error),
}
impl std::fmt::Display for RateLimitError {
    /// Renders a one-line, user-facing description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::NoEnvironment => f.write_str("No environment set"),
            // Use the inner error's `Display`, like the serde arms below. `Debug` output
            // is for diagnostics and is still reachable through `Error::source`.
            Self::Worker(e) => write!(f, "Worker error: {e}"),
            Self::Serialization(e) => write!(f, "Serialization error: {e}"),
            Self::Deserialization(e) => write!(f, "Deserialization error: {e}"),
        }
    }
}
impl std::error::Error for RateLimitError {
    /// Exposes the wrapped worker or serde error; `NoEnvironment` has no underlying cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::NoEnvironment => None,
            Self::Worker(inner) => Some(inner),
            Self::Serialization(inner) | Self::Deserialization(inner) => Some(inner),
        }
    }
}
impl From<worker::Error> for RateLimitError {
fn from(e: worker::Error) -> Self {
Self::Worker(e)
}
}
/// Owned, (de)serializable wire types for the rate-limit endpoints. These are presumably
/// for the Durable Object side, so it can decode what the client sends — confirm usage.
// Kept even when nothing in this crate constructs them; the lint would otherwise fire in
// builds that only use the client half of the protocol.
#[allow(dead_code)]
pub mod protocol {
    use super::*;

    /// Body of `POST /rate-limit/check`. It has the same fields as the request the
    /// client's check/peek methods serialize, but owns `client_id` so it can be
    /// deserialized.
    #[derive(Debug, Serialize, Deserialize)]
    pub struct CheckRequest {
        pub limit: u64,
        pub window_ms: u64,
        pub record: bool,
        pub client_id: String,
    }

    /// Response body of `POST /rate-limit/check`.
    pub type CheckResponse = RateLimitResult;
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::error::Error as _;

    #[test]
    fn test_rate_limiter_creation() {
        let limiter = DurableObjectRateLimiter::new("MCP_RATE_LIMIT")
            .with_limit(50)
            .with_window_ms(30_000);
        assert_eq!(limiter.namespace, "MCP_RATE_LIMIT");
        assert_eq!(limiter.config.limit, 50);
        assert_eq!(limiter.config.window_ms, 30_000);
    }

    #[test]
    fn test_rate_limiter_defaults() {
        let limiter = DurableObjectRateLimiter::new("NS");
        assert!(limiter.env.is_none());
        assert_eq!(limiter.config.limit, 100);
        assert_eq!(limiter.config.window_ms, 60_000);
    }

    #[test]
    fn test_with_config_replaces_both_fields() {
        let limiter =
            DurableObjectRateLimiter::new("NS").with_config(RateLimitConfig::per_second(5));
        assert_eq!(limiter.config.limit, 5);
        assert_eq!(limiter.config.window_ms, 1_000);
    }

    #[test]
    fn test_rate_limit_config_presets() {
        let per_second = RateLimitConfig::per_second(10);
        assert_eq!(per_second.limit, 10);
        assert_eq!(per_second.window_ms, 1_000);
        let per_minute = RateLimitConfig::per_minute(100);
        assert_eq!(per_minute.limit, 100);
        assert_eq!(per_minute.window_ms, 60_000);
        let per_hour = RateLimitConfig::per_hour(1000);
        assert_eq!(per_hour.limit, 1000);
        assert_eq!(per_hour.window_ms, 3_600_000);
    }

    #[test]
    fn test_rate_limit_error_display() {
        let err = RateLimitError::NoEnvironment;
        assert_eq!(err.to_string(), "No environment set");
        assert!(err.source().is_none());

        let json_err = serde_json::from_str::<u64>("not json").unwrap_err();
        let expected = format!("Deserialization error: {json_err}");
        let err = RateLimitError::Deserialization(json_err);
        assert_eq!(err.to_string(), expected);
        assert!(err.source().is_some());
    }

    #[test]
    fn test_check_request_wire_format() {
        let request = protocol::CheckRequest {
            limit: 10,
            window_ms: 1_000,
            record: true,
            client_id: "abc".to_string(),
        };
        let value = serde_json::to_value(&request).unwrap();
        assert_eq!(
            value,
            serde_json::json!({
                "limit": 10,
                "window_ms": 1000,
                "record": true,
                "client_id": "abc"
            })
        );
    }

    #[test]
    fn test_result_decodes_from_json() {
        let json =
            r#"{"allowed":false,"remaining":0,"limit":10,"reset_ms":500,"retry_after_ms":250}"#;
        let result: RateLimitResult = serde_json::from_str(json).unwrap();
        assert!(!result.allowed);
        assert_eq!(result.remaining, 0);
        assert_eq!(result.limit, 10);
        assert_eq!(result.reset_ms, 500);
        assert_eq!(result.retry_after_ms, Some(250));
    }
}