enigma_node_registry/
rate_limit.rs1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::{Duration, Instant};
4
5use tokio::sync::Mutex;
6
7use crate::config::RateLimitConfig;
8use crate::error::{RegistryError, RegistryResult};
9
/// Which token bucket a request is charged against.
///
/// Every checked request is charged to [`RateScope::Global`]; the endpoint variants each
/// select an additional per-endpoint bucket configured from `cfg.endpoints`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RateScope {
    /// Bucket shared by all requests from one IP (rate `per_ip_rps`).
    Global,
    /// Bucket for registration requests (rate `endpoints.register_rps`).
    Register,
    /// Bucket for resolve requests (rate `endpoints.resolve_rps`).
    Resolve,
    /// Bucket for user-check requests (rate `endpoints.check_user_rps`).
    CheckUser,
}
17
/// A token bucket: holds at most `burst` tokens and refills continuously at `rate` tokens
/// per second. Each successful [`Bucket::consume`] spends one token.
struct Bucket {
    /// Tokens currently available; fractional, never above `burst`.
    tokens: f64,
    /// Instant up to which refill has been credited.
    last: Instant,
    /// Refill rate in tokens per second.
    rate: f64,
    /// Capacity of the bucket; also its initial fill level.
    burst: f64,
}

impl Bucket {
    /// Creates a full bucket (`tokens == burst`) whose refill clock starts now.
    fn new(rate: f64, burst: f64) -> Self {
        Bucket {
            tokens: burst,
            last: Instant::now(),
            rate,
            burst,
        }
    }

    /// Credits the refill earned since the last credited instant, then tries to spend one
    /// token.
    ///
    /// Returns `true` if a token was available (and has been spent), `false` otherwise.
    ///
    /// `now` may be earlier than the last credited instant, for example when it was sampled
    /// before waiting on a lock. Such a call earns no refill and leaves the refill clock
    /// where it is: moving it backwards would credit the interval between the two instants
    /// a second time on the next call.
    fn consume(&mut self, now: Instant) -> bool {
        if now > self.last {
            let refill = now.duration_since(self.last).as_secs_f64() * self.rate;
            self.tokens = (self.tokens + refill).min(self.burst);
            self.last = now;
        }
        if self.tokens >= 1.0 {
            self.tokens -= 1.0;
            true
        } else {
            false
        }
    }
}
48
49struct IpEntry {
50 banned_until: Option<Instant>,
51 buckets: HashMap<RateScope, Bucket>,
52}
53
54impl IpEntry {
55 fn new(cfg: &RateLimitConfig) -> Self {
56 let mut buckets = HashMap::new();
57 buckets.insert(
58 RateScope::Global,
59 Bucket::new(cfg.per_ip_rps as f64, cfg.burst as f64),
60 );
61 buckets.insert(
62 RateScope::Register,
63 Bucket::new(cfg.endpoints.register_rps as f64, cfg.burst as f64),
64 );
65 buckets.insert(
66 RateScope::Resolve,
67 Bucket::new(cfg.endpoints.resolve_rps as f64, cfg.burst as f64),
68 );
69 buckets.insert(
70 RateScope::CheckUser,
71 Bucket::new(cfg.endpoints.check_user_rps as f64, cfg.burst as f64),
72 );
73 IpEntry {
74 banned_until: None,
75 buckets,
76 }
77 }
78}
79
80#[derive(Clone)]
81pub struct RateLimiter {
82 cfg: RateLimitConfig,
83 inner: Arc<Mutex<HashMap<String, IpEntry>>>,
84}
85
86impl RateLimiter {
87 pub fn new(cfg: RateLimitConfig) -> Self {
88 RateLimiter {
89 cfg,
90 inner: Arc::new(Mutex::new(HashMap::new())),
91 }
92 }
93
94 pub async fn check(&self, ip: &str, scope: RateScope) -> RegistryResult<()> {
95 if !self.cfg.enabled {
96 return Ok(());
97 }
98 let now = Instant::now();
99 let mut guard = self.inner.lock().await;
100 let entry = guard
101 .entry(ip.to_string())
102 .or_insert_with(|| IpEntry::new(&self.cfg));
103 if let Some(until) = entry.banned_until {
104 if until > now {
105 return Err(RegistryError::RateLimited);
106 }
107 }
108 if !self.consume(entry, RateScope::Global, now) || !self.consume(entry, scope, now) {
109 entry.banned_until = Some(now + Duration::from_secs(self.cfg.ban_seconds));
110 return Err(RegistryError::RateLimited);
111 }
112 Ok(())
113 }
114
115 fn consume(&self, entry: &mut IpEntry, scope: RateScope, now: Instant) -> bool {
116 if let Some(bucket) = entry.buckets.get_mut(&scope) {
117 bucket.consume(now)
118 } else {
119 false
120 }
121 }
122}