1use axum::{
4 extract::{ConnectInfo, Request, State},
5 http::{header, HeaderValue, StatusCode},
6 middleware::Next,
7 response::{IntoResponse, Response},
8};
9use std::collections::HashSet;
10use std::net::SocketAddr;
11use std::sync::Arc;
12use std::time::{Duration, Instant};
13use tokio::sync::RwLock;
14
15pub async fn security_headers(request: Request, next: Next) -> Response {
29 let mut response = next.run(request).await;
30 let headers = response.headers_mut();
31
32 headers.insert(
34 header::X_CONTENT_TYPE_OPTIONS,
35 HeaderValue::from_static("nosniff"),
36 );
37
38 headers.insert(header::X_FRAME_OPTIONS, HeaderValue::from_static("DENY"));
40
41 headers.insert(
43 "X-XSS-Protection",
44 HeaderValue::from_static("1; mode=block"),
45 );
46
47 headers.insert(
49 header::CONTENT_SECURITY_POLICY,
50 HeaderValue::from_static("default-src 'none'; frame-ancestors 'none'"),
51 );
52
53 headers.insert(
55 header::CACHE_CONTROL,
56 HeaderValue::from_static("no-store, max-age=0"),
57 );
58
59 headers.insert(
61 header::REFERRER_POLICY,
62 HeaderValue::from_static("strict-origin-when-cross-origin"),
63 );
64
65 headers.insert(
67 "Permissions-Policy",
68 HeaderValue::from_static("geolocation=(), camera=(), microphone=()"),
69 );
70
71 response
72}
73
/// Settings for the API-key authentication middleware.
///
/// `Debug` is implemented by hand rather than derived so that formatting the
/// config (for example in a startup log line) never reveals the secret keys;
/// only the number of configured keys is shown.
#[derive(Clone)]
pub struct AuthConfig {
    /// Accepted API keys. An empty set means authentication is disabled.
    pub api_keys: HashSet<String>,
    /// Name of the request header that carries the credential.
    pub header_name: String,
    /// Text that must precede the key inside that header (e.g. `"Bearer "`).
    pub prefix: String,
    /// Request paths (exact match) that are served without credentials.
    pub public_paths: HashSet<String>,
}

impl std::fmt::Debug for AuthConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AuthConfig")
            // Never print the secrets themselves.
            .field(
                "api_keys",
                &format_args!("<{} redacted>", self.api_keys.len()),
            )
            .field("header_name", &self.header_name)
            .field("prefix", &self.prefix)
            .field("public_paths", &self.public_paths)
            .finish()
    }
}
90
91impl Default for AuthConfig {
92 fn default() -> Self {
93 Self {
94 api_keys: HashSet::new(),
95 header_name: "Authorization".to_string(),
96 prefix: "Bearer ".to_string(),
97 public_paths: ["/health", "/ready"]
98 .iter()
99 .map(|s| s.to_string())
100 .collect(),
101 }
102 }
103}
104
105impl AuthConfig {
106 pub fn with_keys(keys: impl IntoIterator<Item = String>) -> Self {
108 Self {
109 api_keys: keys.into_iter().collect(),
110 ..Default::default()
111 }
112 }
113
114 pub fn add_public_path(mut self, path: impl Into<String>) -> Self {
116 self.public_paths.insert(path.into());
117 self
118 }
119
120 pub fn requires_auth(&self, path: &str) -> bool {
122 !self.public_paths.contains(path)
123 }
124
125 pub fn validate_key(&self, key: &str) -> bool {
127 self.api_keys.contains(key)
128 }
129
130 pub fn is_enabled(&self) -> bool {
132 !self.api_keys.is_empty()
133 }
134}
135
136#[derive(Clone)]
138pub struct AuthState {
139 pub config: Arc<AuthConfig>,
141}
142
143pub async fn api_key_auth(
145 State(auth): State<AuthState>,
146 request: Request,
147 next: Next,
148) -> Result<Response, Response> {
149 let path = request.uri().path();
150
151 if !auth.config.requires_auth(path) {
153 return Ok(next.run(request).await);
154 }
155
156 if !auth.config.is_enabled() {
158 return Ok(next.run(request).await);
159 }
160
161 let auth_header = request
163 .headers()
164 .get(&auth.config.header_name)
165 .and_then(|v| v.to_str().ok());
166
167 let api_key = match auth_header {
168 Some(value) if value.starts_with(&auth.config.prefix) => &value[auth.config.prefix.len()..],
169 Some(_) => {
170 return Err((
171 StatusCode::UNAUTHORIZED,
172 [(header::WWW_AUTHENTICATE, "Bearer")],
173 "Invalid authorization header format",
174 )
175 .into_response());
176 }
177 None => {
178 return Err((
179 StatusCode::UNAUTHORIZED,
180 [(header::WWW_AUTHENTICATE, "Bearer")],
181 "Missing authorization header",
182 )
183 .into_response());
184 }
185 };
186
187 if !auth.config.validate_key(api_key) {
189 tracing::warn!(
190 path = %path,
191 "Invalid API key attempt"
192 );
193 return Err((
194 StatusCode::UNAUTHORIZED,
195 [(header::WWW_AUTHENTICATE, "Bearer")],
196 "Invalid API key",
197 )
198 .into_response());
199 }
200
201 Ok(next.run(request).await)
202}
203
/// How the rate limiter groups incoming requests into buckets.
#[derive(Debug, Clone, Default)]
pub enum RateLimitKey {
    /// One bucket per client IP address; this is the default strategy.
    #[default]
    ByIp,
    /// One bucket per `Authorization` credential, falling back to the client IP
    /// when that header is absent or not valid text.
    ByApiKey,
}
217
218#[derive(Clone, Debug)]
220pub struct RateLimitConfig {
221 pub max_requests: u32,
223 pub window: Duration,
225 pub key_strategy: RateLimitKey,
227}
228
229impl Default for RateLimitConfig {
230 fn default() -> Self {
231 Self {
232 max_requests: 100,
233 window: Duration::from_secs(60),
234 key_strategy: RateLimitKey::ByIp,
235 }
236 }
237}
238
/// Request counter for a single bucket within its current window.
#[derive(Clone)]
struct RateLimitEntry {
    /// Requests counted since `window_start`.
    count: u32,
    /// Moment the current window opened.
    window_start: Instant,
}
245
246#[derive(Clone)]
248pub struct RateLimitState {
249 pub config: Arc<RateLimitConfig>,
251 entries: Arc<RwLock<std::collections::HashMap<String, RateLimitEntry>>>,
252}
253
254impl RateLimitState {
255 pub fn new(config: RateLimitConfig) -> Self {
257 Self {
258 config: Arc::new(config),
259 entries: Arc::new(RwLock::new(std::collections::HashMap::new())),
260 }
261 }
262
263 pub async fn cleanup(&self) {
265 let now = Instant::now();
266 let window = self.config.window;
267 let mut entries = self.entries.write().await;
268 entries.retain(|_, entry| now.duration_since(entry.window_start) < window);
269 }
270
271 async fn check_and_increment(&self, key: String) -> Result<(u32, u32), (u32, Duration)> {
273 let now = Instant::now();
274 let mut entries = self.entries.write().await;
275
276 let entry = entries.entry(key).or_insert_with(|| RateLimitEntry {
277 count: 0,
278 window_start: now,
279 });
280
281 if now.duration_since(entry.window_start) >= self.config.window {
283 entry.count = 0;
284 entry.window_start = now;
285 }
286
287 entry.count += 1;
288
289 if entry.count > self.config.max_requests {
290 let retry_after = self.config.window - now.duration_since(entry.window_start);
291 Err((entry.count, retry_after))
292 } else {
293 Ok((
294 self.config.max_requests - entry.count,
295 self.config.max_requests,
296 ))
297 }
298 }
299}
300
301pub async fn rate_limit(
303 State(state): State<RateLimitState>,
304 ConnectInfo(addr): ConnectInfo<SocketAddr>,
305 request: Request,
306 next: Next,
307) -> Result<Response, Response> {
308 let key = match state.config.key_strategy {
309 RateLimitKey::ByIp => addr.ip().to_string(),
310 RateLimitKey::ByApiKey => {
311 request
313 .headers()
314 .get(header::AUTHORIZATION)
315 .and_then(|v| v.to_str().ok())
316 .map(|s| s.trim_start_matches("Bearer ").to_string())
317 .unwrap_or_else(|| addr.ip().to_string())
318 }
319 };
320
321 match state.check_and_increment(key).await {
322 Ok((remaining, limit)) => {
323 let mut response = next.run(request).await;
324 let headers = response.headers_mut();
325
326 headers.insert(
328 "X-RateLimit-Limit",
329 HeaderValue::from_str(&limit.to_string()).unwrap(),
330 );
331 headers.insert(
332 "X-RateLimit-Remaining",
333 HeaderValue::from_str(&remaining.to_string()).unwrap(),
334 );
335
336 Ok(response)
337 }
338 Err((_, retry_after)) => {
339 let retry_secs = retry_after.as_secs().max(1);
340 Err((
341 StatusCode::TOO_MANY_REQUESTS,
342 [
343 ("Retry-After", retry_secs.to_string()),
344 ("X-RateLimit-Limit", state.config.max_requests.to_string()),
345 ("X-RateLimit-Remaining", "0".to_string()),
346 ],
347 "Rate limit exceeded",
348 )
349 .into_response())
350 }
351 }
352}
353
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_auth_config_default() {
        let cfg = AuthConfig::default();
        assert!(!cfg.is_enabled());
        for public in ["/health", "/ready"] {
            assert!(cfg.public_paths.contains(public));
        }
    }

    #[test]
    fn test_auth_config_with_keys() {
        let cfg = AuthConfig::with_keys(vec!["key1".to_string(), "key2".to_string()]);
        assert!(cfg.is_enabled());
        for known in ["key1", "key2"] {
            assert!(cfg.validate_key(known));
        }
        assert!(!cfg.validate_key("key3"));
    }

    #[test]
    fn test_auth_config_public_paths() {
        let cfg = AuthConfig::default().add_public_path("/metrics");
        for open in ["/health", "/ready", "/metrics"] {
            assert!(!cfg.requires_auth(open));
        }
        assert!(cfg.requires_auth("/v1/state"));
    }

    #[test]
    fn test_rate_limit_config_default() {
        let RateLimitConfig {
            max_requests,
            window,
            ..
        } = RateLimitConfig::default();
        assert_eq!(max_requests, 100);
        assert_eq!(window, Duration::from_secs(60));
    }

    #[tokio::test]
    async fn test_rate_limit_state() {
        let limiter = RateLimitState::new(RateLimitConfig {
            max_requests: 3,
            window: Duration::from_secs(60),
            key_strategy: RateLimitKey::ByIp,
        });

        // The first three requests fit inside the limit...
        for _ in 0..3 {
            assert!(limiter.check_and_increment("test".to_string()).await.is_ok());
        }

        // ...the fourth one is rejected...
        assert!(limiter.check_and_increment("test".to_string()).await.is_err());

        // ...while an unrelated key still has its own budget.
        assert!(limiter.check_and_increment("other".to_string()).await.is_ok());
    }
}