// oxidite_middleware/rate_limit.rs

use std::collections::HashMap;
use std::sync::Arc;

use oxidite_db::Database;
use tokio::sync::Mutex;

/// Limits enforced by [`RateLimiter`] for each (identifier, endpoint) pair.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct RateLimitConfig {
    /// Maximum number of requests accepted within any trailing 60-second window.
    pub requests_per_minute: u32,
    /// Maximum number of requests accepted within any trailing 3600-second window.
    /// `None` disables the hourly limit.
    pub requests_per_hour: Option<u32>,
}

impl Default for RateLimitConfig {
    /// 60 requests per minute and 1000 requests per hour.
    fn default() -> Self {
        Self {
            requests_per_minute: 60,
            requests_per_hour: Some(1000),
        }
    }
}

22pub struct RateLimiter {
24 db: Option<Arc<dyn Database>>,
25 config: RateLimitConfig,
26 cache: Arc<Mutex<HashMap<String, Vec<i64>>>>,
28}
29
30impl RateLimiter {
31 pub fn new(config: RateLimitConfig) -> Self {
32 Self {
33 db: None,
34 config,
35 cache: Arc::new(Mutex::new(HashMap::new())),
36 }
37 }
38
39 pub fn with_db(config: RateLimitConfig, db: Arc<dyn Database>) -> Self {
40 Self {
41 db: Some(db),
42 config,
43 cache: Arc::new(Mutex::new(HashMap::new())),
44 }
45 }
46
47 pub async fn check(&self, identifier: &str, endpoint: &str) -> bool {
49 let now = chrono::Utc::now().timestamp();
50 let minute_ago = now - 60;
51 let hour_ago = now - 3600;
52
53 let mut cache = self.cache.lock().await;
55 let key = format!("{}:{}", identifier, endpoint);
56
57 let timestamps = cache.entry(key.clone()).or_insert_with(Vec::new);
59
60 timestamps.retain(|&ts| ts > hour_ago);
62
63 let minute_count = timestamps.iter().filter(|&&ts| ts > minute_ago).count() as u32;
65
66 if minute_count >= self.config.requests_per_minute {
68 return false;
69 }
70
71 if let Some(hour_limit) = self.config.requests_per_hour {
73 let hour_count = timestamps.len() as u32;
74 if hour_count >= hour_limit {
75 return false;
76 }
77 }
78
79 timestamps.push(now);
81
82 if let Some(db) = &self.db {
84 let db_clone = db.clone();
85 let ident = identifier.to_string();
86 let ep = endpoint.to_string();
87 tokio::spawn(async move {
88 let _ = Self::record_request(&*db_clone, &ident, &ep).await;
89 });
90 }
91
92 true
93 }
94
95 async fn record_request(db: &dyn Database, identifier: &str, endpoint: &str) -> oxidite_db::Result<()> {
97 let now = chrono::Utc::now().timestamp();
98 let window_start = (now / 60) * 60; let update_query = format!(
102 "UPDATE rate_limits
103 SET request_count = request_count + 1, updated_at = {}
104 WHERE identifier = '{}' AND endpoint = '{}' AND window_start = {}",
105 now, identifier, endpoint, window_start
106 );
107
108 let rows = db.execute(&update_query).await?;
109
110 if rows == 0 {
112 let insert_query = format!(
113 "INSERT INTO rate_limits (identifier, endpoint, request_count, window_start, created_at, updated_at)
114 VALUES ('{}', '{}', 1, {}, {}, {})",
115 identifier, endpoint, window_start, now, now
116 );
117 db.execute(&insert_query).await?;
118 }
119
120 Ok(())
121 }
122
123 pub async fn get_remaining(&self, identifier: &str, endpoint: &str) -> u32 {
125 let now = chrono::Utc::now().timestamp();
126 let minute_ago = now - 60;
127
128 let cache = self.cache.lock().await;
129 let key = format!("{}:{}", identifier, endpoint);
130
131 if let Some(timestamps) = cache.get(&key) {
132 let minute_count = timestamps.iter().filter(|&&ts| ts > minute_ago).count() as u32;
133 self.config.requests_per_minute.saturating_sub(minute_count)
134 } else {
135 self.config.requests_per_minute
136 }
137 }
138
139 pub async fn cleanup(&self) {
141 let now = chrono::Utc::now().timestamp();
142 let hour_ago = now - 3600;
143
144 let mut cache = self.cache.lock().await;
145 cache.retain(|_, timestamps| {
146 timestamps.retain(|&ts| ts > hour_ago);
147 !timestamps.is_empty()
148 });
149 }
150}