Skip to main content

detritus_server/
rate_limit.rs

1use std::{collections::HashMap, sync::Arc, time::Instant};
2
3use serde::Deserialize;
4use tokio::sync::Mutex;
5
6use crate::{auth::TokenContext, storage::SourceKey};
7
8/// Per-token rate limits for log and crash ingestion.
9#[derive(Debug, Clone, Copy, Deserialize)]
10pub struct RateLimitConfig {
11    /// Sustained log export batches permitted per minute.
12    #[serde(default = "default_logs_per_minute")]
13    pub logs_per_minute: u32,
14    /// Burst capacity for log export batches.
15    #[serde(default = "default_logs_burst")]
16    pub logs_burst: u32,
17    /// Sustained crash uploads permitted per minute.
18    #[serde(default = "default_crashes_per_minute")]
19    pub crashes_per_minute: u32,
20    /// Burst capacity for crash uploads.
21    #[serde(default = "default_crashes_burst")]
22    pub crashes_burst: u32,
23}
24
25impl Default for RateLimitConfig {
26    fn default() -> Self {
27        Self {
28            logs_per_minute: default_logs_per_minute(),
29            logs_burst: default_logs_burst(),
30            crashes_per_minute: default_crashes_per_minute(),
31            crashes_burst: default_crashes_burst(),
32        }
33    }
34}
35
36#[derive(Debug, Clone)]
37pub(crate) struct RateLimiter {
38    buckets: Arc<Mutex<HashMap<RateKey, Bucket>>>,
39    config: RateLimitConfig,
40}
41
42impl RateLimiter {
43    pub(crate) fn new(config: RateLimitConfig) -> Self {
44        Self {
45            buckets: Arc::new(Mutex::new(HashMap::new())),
46            config,
47        }
48    }
49
50    pub(crate) async fn check_logs(
51        &self,
52        token: &TokenContext,
53        source: &SourceKey,
54    ) -> Result<(), RateLimitError> {
55        self.check(
56            "logs",
57            token,
58            source,
59            self.config.logs_per_minute,
60            self.config.logs_burst,
61        )
62        .await
63    }
64
65    pub(crate) async fn check_crashes(
66        &self,
67        token: &TokenContext,
68        source: &SourceKey,
69    ) -> Result<(), RateLimitError> {
70        self.check(
71            "crashes",
72            token,
73            source,
74            self.config.crashes_per_minute,
75            self.config.crashes_burst,
76        )
77        .await
78    }
79
80    async fn check(
81        &self,
82        endpoint: &'static str,
83        token: &TokenContext,
84        source: &SourceKey,
85        per_minute: u32,
86        burst: u32,
87    ) -> Result<(), RateLimitError> {
88        let mut buckets = self.buckets.lock().await;
89        let key = RateKey {
90            endpoint,
91            token_id: token.id.clone(),
92            source: source.canonical(),
93        };
94        let now = Instant::now();
95        let bucket = buckets.entry(key).or_insert_with(|| Bucket {
96            tokens: f64::from(burst),
97            last_refill: now,
98        });
99        let refill_per_second = f64::from(per_minute) / 60.0;
100        let elapsed = now.duration_since(bucket.last_refill).as_secs_f64();
101        bucket.tokens = (bucket.tokens + elapsed * refill_per_second).min(f64::from(burst));
102        bucket.last_refill = now;
103        if bucket.tokens >= 1.0 {
104            bucket.tokens -= 1.0;
105            Ok(())
106        } else {
107            Err(RateLimitError)
108        }
109    }
110}
111
/// Identity of one token bucket: which endpoint, which token, which source.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
struct RateKey {
    /// Endpoint label supplied by the limiter (for example `"logs"`).
    endpoint: &'static str,
    /// Id of the authenticated token, copied from its `TokenContext`.
    token_id: String,
    /// Canonical string form of the reporting source.
    source: String,
}
118
/// Remaining allowance for one key, and when it was last brought up to date.
#[derive(Clone, Debug)]
struct Bucket {
    /// Tokens currently available. Fractional, because refill is continuous.
    tokens: f64,
    /// Instant at which `tokens` was last recomputed.
    last_refill: Instant,
}
124
125#[derive(Debug, Clone, Copy, thiserror::Error)]
126#[error("rate limit exceeded")]
127pub(crate) struct RateLimitError;
128
/// Built-in sustained rate for log export batches, per minute.
const DEFAULT_LOGS_PER_MINUTE: u32 = 1000;
/// Built-in burst capacity for log export batches.
const DEFAULT_LOGS_BURST: u32 = 200;
/// Built-in sustained rate for crash uploads, per minute.
const DEFAULT_CRASHES_PER_MINUTE: u32 = 30;
/// Built-in burst capacity for crash uploads.
const DEFAULT_CRASHES_BURST: u32 = 5;

// serde's `#[serde(default = "...")]` attribute takes a function path, so each
// constant is exposed through a trivial accessor.

fn default_logs_per_minute() -> u32 {
    DEFAULT_LOGS_PER_MINUTE
}

fn default_logs_burst() -> u32 {
    DEFAULT_LOGS_BURST
}

fn default_crashes_per_minute() -> u32 {
    DEFAULT_CRASHES_PER_MINUTE
}

fn default_crashes_burst() -> u32 {
    DEFAULT_CRASHES_BURST
}