use super::error::{OrganizationError, Result};
use governor::{
Quota, RateLimiter, clock::DefaultClock, middleware::NoOpMiddleware,
state::keyed::DashMapStateStore,
};
use std::{
num::NonZeroU32,
sync::{
Arc,
atomic::{AtomicU64, Ordering},
},
time::Duration,
};
/// Number of `InvitationRateLimiter::check` calls between sweeps that call
/// `retain_recent` on both keyed limiters.
const SHRINK_INTERVAL: u64 = 1000;
/// Limits applied to invitation creation, per organization and per actor.
#[derive(Clone, Debug)]
pub struct InvitationRateLimitConfig {
/// Maximum number of invitations allowed for a single organization key.
pub max_per_org: u32,
/// Maximum number of invitations allowed for a single actor key.
pub max_per_actor: u32,
/// Length of the time window the limits refer to, in seconds.
pub window_seconds: u64,
}
impl Default for InvitationRateLimitConfig {
    /// Returns the default limits: 50 invitations per organization and 20 per
    /// actor, over a 3600-second window.
    fn default() -> Self {
        let (max_per_org, max_per_actor, window_seconds) = (50, 20, 3600);
        Self {
            max_per_org,
            max_per_actor,
            window_seconds,
        }
    }
}
impl InvitationRateLimitConfig {
    /// Creates a configuration from explicit limits.
    ///
    /// The values are stored exactly as given; no validation or clamping
    /// happens here.
    pub fn new(max_per_org: u32, max_per_actor: u32, window_seconds: u64) -> Self {
        Self {
            max_per_org,
            max_per_actor,
            window_seconds,
        }
    }

    /// Tight preset: 10 invitations per organization and 5 per actor, over a
    /// 3600-second window.
    pub fn strict() -> Self {
        // Argument order: per-org limit, per-actor limit, window in seconds.
        Self::new(10, 5, 3600)
    }

    /// Generous preset: 200 invitations per organization and 50 per actor,
    /// over a 3600-second window.
    pub fn lenient() -> Self {
        // Argument order: per-org limit, per-actor limit, window in seconds.
        Self::new(200, 50, 3600)
    }
}
/// `governor` rate limiter keyed by `String`, using a `DashMapStateStore`,
/// the default clock and no middleware.
type KeyedLimiter = RateLimiter<String, DashMapStateStore<String>, DefaultClock, NoOpMiddleware>;
/// Rate limiter for invitation creation that tracks organizations and actors
/// with two independent keyed limiters.
///
/// The limiters and the request counter are held behind `Arc`, so clones of
/// this value share the same limiter state and counter.
#[derive(Clone)]
pub struct InvitationRateLimiter {
/// Limiter keyed by organization id.
org_limiter: Arc<KeyedLimiter>,
/// Limiter keyed by actor id.
actor_limiter: Arc<KeyedLimiter>,
/// Configuration this limiter was built from.
config: InvitationRateLimitConfig,
/// Count of `check` calls, used to schedule periodic `retain_recent` sweeps.
request_count: Arc<AtomicU64>,
}
impl InvitationRateLimiter {
pub fn new(config: InvitationRateLimitConfig) -> Self {
let max_per_org = NonZeroU32::new(config.max_per_org.max(1)).unwrap_or(NonZeroU32::MIN);
let max_per_actor = NonZeroU32::new(config.max_per_actor.max(1)).unwrap_or(NonZeroU32::MIN);
let window = Duration::from_secs(config.window_seconds.max(1));
let org_quota = Quota::with_period(window)
.unwrap_or_else(|| Quota::per_second(max_per_org))
.allow_burst(max_per_org);
let actor_quota = Quota::with_period(window)
.unwrap_or_else(|| Quota::per_second(max_per_actor))
.allow_burst(max_per_actor);
Self {
org_limiter: Arc::new(RateLimiter::keyed(org_quota)),
actor_limiter: Arc::new(RateLimiter::keyed(actor_quota)),
config,
request_count: Arc::new(AtomicU64::new(0)),
}
}
pub fn check(&self, org_id: &str, actor_id: &str) -> std::result::Result<(), (String, u64)> {
let count = self.request_count.fetch_add(1, Ordering::Relaxed);
if count % SHRINK_INTERVAL == 0 && count > 0 {
self.org_limiter.retain_recent();
self.actor_limiter.retain_recent();
}
if let Err(not_until) = self.org_limiter.check_key(&org_id.to_string()) {
let wait =
not_until.wait_time_from(governor::clock::Clock::now(&DefaultClock::default()));
return Err(("organization".to_string(), wait.as_secs().max(1)));
}
if let Err(not_until) = self.actor_limiter.check_key(&actor_id.to_string()) {
let wait =
not_until.wait_time_from(governor::clock::Clock::now(&DefaultClock::default()));
return Err(("actor".to_string(), wait.as_secs().max(1)));
}
Ok(())
}
pub fn config(&self) -> &InvitationRateLimitConfig {
&self.config
}
}
/// Pluggable invitation rate limiting.
///
/// Implemented in this file by `()` (no limiting) and by
/// `WithInvitationRateLimiter` (limiting enabled).
pub trait OptionalInvitationRateLimiter: Send + Sync + Clone + 'static {
/// Checks whether an invitation by `actor_id` in `org_id` may proceed,
/// returning an error when it must be rejected.
fn check_invitation_rate(&self, org_id: &str, actor_id: &str) -> Result<()>;
}
/// No-op implementation: using `()` as the limiter disables invitation rate
/// limiting.
impl OptionalInvitationRateLimiter for () {
/// Always returns `Ok(())`; both ids are ignored.
fn check_invitation_rate(&self, _org_id: &str, _actor_id: &str) -> Result<()> {
Ok(())
}
}
/// Newtype around `InvitationRateLimiter` that implements
/// `OptionalInvitationRateLimiter` with rate limiting enabled.
#[derive(Clone)]
pub struct WithInvitationRateLimiter(pub InvitationRateLimiter);
impl OptionalInvitationRateLimiter for WithInvitationRateLimiter {
    /// Delegates to the wrapped `InvitationRateLimiter`.
    ///
    /// # Errors
    ///
    /// When the wrapped limiter rejects the request, a warning is logged with
    /// the ids, the limit that was hit and the configured limits, and
    /// `OrganizationError::RateLimited` is returned carrying the retry delay
    /// in seconds.
    fn check_invitation_rate(&self, org_id: &str, actor_id: &str) -> Result<()> {
        let limiter = &self.0;
        limiter
            .check(org_id, actor_id)
            .map_err(|(limit_type, retry_after)| {
                let limits = &limiter.config;
                tracing::warn!(
                    target: "orgs.invitation.rate_limited",
                    org_id = %org_id,
                    actor_id = %actor_id,
                    limit_type = %limit_type,
                    retry_after_secs = retry_after,
                    max_per_org = limits.max_per_org,
                    max_per_actor = limits.max_per_actor,
                    window_secs = limits.window_seconds,
                    "Invitation rate limited"
                );
                OrganizationError::RateLimited {
                    retry_after_seconds: retry_after,
                }
            })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Outcome type of `InvitationRateLimiter::check`, spelled out because
    /// `Result` may resolve to the crate's own alias inside this module.
    type CheckOutcome = std::result::Result<(), (String, u64)>;

    /// Builds a limiter with the given maxima and a 60-second window.
    fn limiter_with(max_per_org: u32, max_per_actor: u32) -> InvitationRateLimiter {
        InvitationRateLimiter::new(InvitationRateLimitConfig::new(
            max_per_org,
            max_per_actor,
            60,
        ))
    }

    /// Asserts that `outcome` is a rejection attributed to `expected_limit`.
    fn assert_fourth_blocked_by(outcome: CheckOutcome, expected_limit: &str) {
        match outcome {
            Ok(()) => panic!("4th request should be blocked"),
            Err((limit_type, _)) => assert_eq!(limit_type, expected_limit),
        }
    }

    /// Requests within both limits are all accepted.
    #[test]
    fn test_rate_limit_allows_requests_under_limit() {
        let limiter = limiter_with(5, 3);
        for attempt in 1..=3 {
            assert!(
                limiter.check("org_1", "actor_1").is_ok(),
                "Request {} should be allowed",
                attempt
            );
        }
    }

    /// Exhausting the actor quota rejects the next request as "actor".
    #[test]
    fn test_rate_limit_blocks_requests_over_actor_limit() {
        let limiter = limiter_with(10, 3);
        (0..3).for_each(|_| limiter.check("org_1", "actor_1").unwrap());
        assert_fourth_blocked_by(limiter.check("org_1", "actor_1"), "actor");
    }

    /// Exhausting the organization quota (using distinct actors) rejects the
    /// next request as "organization".
    #[test]
    fn test_rate_limit_blocks_requests_over_org_limit() {
        let limiter = limiter_with(3, 10);
        (0..3).for_each(|i| limiter.check("org_1", &format!("actor_{i}")).unwrap());
        assert_fourth_blocked_by(limiter.check("org_1", "actor_new"), "organization");
    }

    /// One organization running out of quota does not affect another.
    #[test]
    fn test_rate_limit_per_org_isolation() {
        let limiter = limiter_with(3, 10);
        (0..3).for_each(|_| limiter.check("org_1", "actor_1").unwrap());
        assert!(
            limiter.check("org_2", "actor_1").is_ok(),
            "Different org should have separate quota"
        );
    }

    /// One actor running out of quota does not affect another.
    #[test]
    fn test_rate_limit_per_actor_isolation() {
        let limiter = limiter_with(10, 3);
        (0..3).for_each(|_| limiter.check("org_1", "actor_1").unwrap());
        assert!(
            limiter.check("org_1", "actor_2").is_ok(),
            "Different actor should have separate quota"
        );
    }

    /// The `()` implementation never rejects.
    #[test]
    fn test_optional_rate_limiter_noop() {
        assert!(().check_invitation_rate("org_1", "actor_1").is_ok());
    }

    /// The wrapper enforces the wrapped limiter's actor quota.
    #[test]
    fn test_optional_rate_limiter_with_limiter() {
        let limiter = WithInvitationRateLimiter(limiter_with(10, 2));
        let mut allowed = || limiter.check_invitation_rate("org_1", "actor_1").is_ok();
        assert!(allowed());
        assert!(allowed());
        assert!(!allowed());
    }

    /// Each preset carries its documented limits and a one-hour window.
    #[test]
    fn test_config_presets() {
        let presets = [
            (InvitationRateLimitConfig::default(), 50, 20),
            (InvitationRateLimitConfig::strict(), 10, 5),
            (InvitationRateLimitConfig::lenient(), 200, 50),
        ];
        for (config, max_per_org, max_per_actor) in presets {
            assert_eq!(config.max_per_org, max_per_org);
            assert_eq!(config.max_per_actor, max_per_actor);
            assert_eq!(config.window_seconds, 3600);
        }
    }
}