// detritus_server/rate_limit.rs

use std::{
    collections::HashMap,
    sync::Arc,
    time::{Duration, Instant},
};

use serde::Deserialize;
use tokio::sync::Mutex;

use crate::{auth::TokenContext, storage::SourceKey};
7
8#[derive(Debug, Clone, Copy, Deserialize)]
10pub struct RateLimitConfig {
11 #[serde(default = "default_logs_per_minute")]
13 pub logs_per_minute: u32,
14 #[serde(default = "default_logs_burst")]
16 pub logs_burst: u32,
17 #[serde(default = "default_crashes_per_minute")]
19 pub crashes_per_minute: u32,
20 #[serde(default = "default_crashes_burst")]
22 pub crashes_burst: u32,
23}
24
25impl Default for RateLimitConfig {
26 fn default() -> Self {
27 Self {
28 logs_per_minute: default_logs_per_minute(),
29 logs_burst: default_logs_burst(),
30 crashes_per_minute: default_crashes_per_minute(),
31 crashes_burst: default_crashes_burst(),
32 }
33 }
34}
35
36#[derive(Debug, Clone)]
37pub(crate) struct RateLimiter {
38 buckets: Arc<Mutex<HashMap<RateKey, Bucket>>>,
39 config: RateLimitConfig,
40}
41
42impl RateLimiter {
43 pub(crate) fn new(config: RateLimitConfig) -> Self {
44 Self {
45 buckets: Arc::new(Mutex::new(HashMap::new())),
46 config,
47 }
48 }
49
50 pub(crate) async fn check_logs(
51 &self,
52 token: &TokenContext,
53 source: &SourceKey,
54 ) -> Result<(), RateLimitError> {
55 self.check(
56 "logs",
57 token,
58 source,
59 self.config.logs_per_minute,
60 self.config.logs_burst,
61 )
62 .await
63 }
64
65 pub(crate) async fn check_crashes(
66 &self,
67 token: &TokenContext,
68 source: &SourceKey,
69 ) -> Result<(), RateLimitError> {
70 self.check(
71 "crashes",
72 token,
73 source,
74 self.config.crashes_per_minute,
75 self.config.crashes_burst,
76 )
77 .await
78 }
79
80 async fn check(
81 &self,
82 endpoint: &'static str,
83 token: &TokenContext,
84 source: &SourceKey,
85 per_minute: u32,
86 burst: u32,
87 ) -> Result<(), RateLimitError> {
88 let mut buckets = self.buckets.lock().await;
89 let key = RateKey {
90 endpoint,
91 token_id: token.id.clone(),
92 source: source.canonical(),
93 };
94 let now = Instant::now();
95 let bucket = buckets.entry(key).or_insert_with(|| Bucket {
96 tokens: f64::from(burst),
97 last_refill: now,
98 });
99 let refill_per_second = f64::from(per_minute) / 60.0;
100 let elapsed = now.duration_since(bucket.last_refill).as_secs_f64();
101 bucket.tokens = (bucket.tokens + elapsed * refill_per_second).min(f64::from(burst));
102 bucket.last_refill = now;
103 if bucket.tokens >= 1.0 {
104 bucket.tokens -= 1.0;
105 Ok(())
106 } else {
107 Err(RateLimitError)
108 }
109 }
110}
111
/// Identifies one token bucket.
///
/// A bucket is identified by the limited endpoint, the caller's token id and the canonical
/// form of the source.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct RateKey {
    /// Endpoint label supplied by [`RateLimiter`] (`"logs"` or `"crashes"`).
    endpoint: &'static str,
    /// Copied from `TokenContext::id`.
    token_id: String,
    /// Result of `SourceKey::canonical()`.
    source: String,
}
118
/// Token-bucket state for a single [`RateKey`].
#[derive(Debug, Clone)]
struct Bucket {
    /// Tokens currently available. Fractional because refill is proportional to elapsed time.
    tokens: f64,
    /// The moment `tokens` was last brought up to date.
    last_refill: Instant,
}
124
125#[derive(Debug, Clone, Copy, thiserror::Error)]
126#[error("rate limit exceeded")]
127pub(crate) struct RateLimitError;
128
/// Default value for `RateLimitConfig::logs_per_minute` (one thousand tokens per minute).
fn default_logs_per_minute() -> u32 {
    1000
}
132
/// Default value for `RateLimitConfig::logs_burst` (bucket capacity of two hundred tokens).
fn default_logs_burst() -> u32 {
    200
}
136
/// Default value for `RateLimitConfig::crashes_per_minute` (thirty tokens per minute).
fn default_crashes_per_minute() -> u32 {
    30
}
140
/// Default value for `RateLimitConfig::crashes_burst` (bucket capacity of five tokens).
fn default_crashes_burst() -> u32 {
    5
}