// enigma_node_registry/rate_limit.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::{Duration, Instant};
4
5use tokio::sync::Mutex;
6
7use crate::config::RateLimitConfig;
8use crate::error::{RegistryError, RegistryResult};
9
/// Names the token bucket a request is charged against.
///
/// `RateLimiter::check` charges `Global` first and then the requested scope.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RateScope {
    /// Overall per-IP budget.
    Global,
    /// Registration endpoint budget.
    Register,
    /// Resolve endpoint budget.
    Resolve,
    /// Check-user endpoint budget.
    CheckUser,
}
17
/// A token bucket holding up to `burst` tokens.
///
/// The bucket refills continuously at `rate` tokens per second.
struct Bucket {
    /// Tokens currently available. The count is fractional while refilling.
    tokens: f64,
    /// Instant up to which refill has already been credited.
    last: Instant,
    /// Refill rate in tokens per second.
    rate: f64,
    /// Capacity of the bucket, which is also its initial token count.
    burst: f64,
}

impl Bucket {
    /// Creates a full bucket whose refill clock starts now.
    fn new(rate: f64, burst: f64) -> Self {
        Bucket {
            tokens: burst,
            last: Instant::now(),
            rate,
            burst,
        }
    }

    /// Credits refill for the time since the last refill, then tries to take one token.
    ///
    /// Returns `true` if a token was taken.
    ///
    /// A `now` that is not after the last refill instant adds no tokens and leaves the
    /// refill clock untouched. This covers the case where one caller's clock reading is
    /// older than another's. Rewinding the clock instead would let a later call credit
    /// the same interval a second time.
    fn consume(&mut self, now: Instant) -> bool {
        if now > self.last {
            let elapsed = now.saturating_duration_since(self.last).as_secs_f64();
            self.tokens = (self.tokens + elapsed * self.rate).min(self.burst);
            self.last = now;
        }
        if self.tokens >= 1.0 {
            self.tokens -= 1.0;
            true
        } else {
            false
        }
    }
}
48
49struct IpEntry {
50    banned_until: Option<Instant>,
51    buckets: HashMap<RateScope, Bucket>,
52}
53
54impl IpEntry {
55    fn new(cfg: &RateLimitConfig) -> Self {
56        let mut buckets = HashMap::new();
57        buckets.insert(
58            RateScope::Global,
59            Bucket::new(cfg.per_ip_rps as f64, cfg.burst as f64),
60        );
61        buckets.insert(
62            RateScope::Register,
63            Bucket::new(cfg.endpoints.register_rps as f64, cfg.burst as f64),
64        );
65        buckets.insert(
66            RateScope::Resolve,
67            Bucket::new(cfg.endpoints.resolve_rps as f64, cfg.burst as f64),
68        );
69        buckets.insert(
70            RateScope::CheckUser,
71            Bucket::new(cfg.endpoints.check_user_rps as f64, cfg.burst as f64),
72        );
73        IpEntry {
74            banned_until: None,
75            buckets,
76        }
77    }
78}
79
80#[derive(Clone)]
81pub struct RateLimiter {
82    cfg: RateLimitConfig,
83    inner: Arc<Mutex<HashMap<String, IpEntry>>>,
84}
85
86impl RateLimiter {
87    pub fn new(cfg: RateLimitConfig) -> Self {
88        RateLimiter {
89            cfg,
90            inner: Arc::new(Mutex::new(HashMap::new())),
91        }
92    }
93
94    pub async fn check(&self, ip: &str, scope: RateScope) -> RegistryResult<()> {
95        if !self.cfg.enabled {
96            return Ok(());
97        }
98        let now = Instant::now();
99        let mut guard = self.inner.lock().await;
100        let entry = guard
101            .entry(ip.to_string())
102            .or_insert_with(|| IpEntry::new(&self.cfg));
103        if let Some(until) = entry.banned_until {
104            if until > now {
105                return Err(RegistryError::RateLimited);
106            }
107        }
108        if !self.consume(entry, RateScope::Global, now) || !self.consume(entry, scope, now) {
109            entry.banned_until = Some(now + Duration::from_secs(self.cfg.ban_seconds));
110            return Err(RegistryError::RateLimited);
111        }
112        Ok(())
113    }
114
115    fn consume(&self, entry: &mut IpEntry, scope: RateScope, now: Instant) -> bool {
116        if let Some(bucket) = entry.buckets.get_mut(&scope) {
117            bucket.consume(now)
118        } else {
119            false
120        }
121    }
122}