attuned_http/
middleware.rs

1//! HTTP middleware for security, rate limiting, and authentication.
2
3use axum::{
4    extract::{ConnectInfo, Request, State},
5    http::{header, HeaderValue, StatusCode},
6    middleware::Next,
7    response::{IntoResponse, Response},
8};
9use std::collections::HashSet;
10use std::net::SocketAddr;
11use std::sync::Arc;
12use std::time::{Duration, Instant};
13use tokio::sync::RwLock;
14
15// ============================================================================
16// Security Headers Middleware
17// ============================================================================
18
19/// Add security headers to all responses.
20///
21/// Headers added:
22/// - `X-Content-Type-Options: nosniff` - Prevent MIME sniffing
23/// - `X-Frame-Options: DENY` - Prevent clickjacking
24/// - `X-XSS-Protection: 1; mode=block` - Legacy XSS protection
25/// - `Content-Security-Policy: default-src 'none'` - Strict CSP
26/// - `Cache-Control: no-store` - Prevent caching of sensitive data
27/// - `Referrer-Policy: strict-origin-when-cross-origin` - Control referrer info
28pub async fn security_headers(request: Request, next: Next) -> Response {
29    let mut response = next.run(request).await;
30    let headers = response.headers_mut();
31
32    // Prevent MIME type sniffing
33    headers.insert(
34        header::X_CONTENT_TYPE_OPTIONS,
35        HeaderValue::from_static("nosniff"),
36    );
37
38    // Prevent clickjacking
39    headers.insert(header::X_FRAME_OPTIONS, HeaderValue::from_static("DENY"));
40
41    // Legacy XSS protection (still useful for older browsers)
42    headers.insert(
43        "X-XSS-Protection",
44        HeaderValue::from_static("1; mode=block"),
45    );
46
47    // Strict Content Security Policy (API-only, no inline content)
48    headers.insert(
49        header::CONTENT_SECURITY_POLICY,
50        HeaderValue::from_static("default-src 'none'; frame-ancestors 'none'"),
51    );
52
53    // Prevent caching of potentially sensitive responses
54    headers.insert(
55        header::CACHE_CONTROL,
56        HeaderValue::from_static("no-store, max-age=0"),
57    );
58
59    // Control referrer information
60    headers.insert(
61        header::REFERRER_POLICY,
62        HeaderValue::from_static("strict-origin-when-cross-origin"),
63    );
64
65    // Permissions policy (disable all browser features)
66    headers.insert(
67        "Permissions-Policy",
68        HeaderValue::from_static("geolocation=(), camera=(), microphone=()"),
69    );
70
71    response
72}
73
74// ============================================================================
75// API Key Authentication Middleware
76// ============================================================================
77
/// Configuration for API key authentication.
///
/// `Debug` is implemented by hand rather than derived, so formatting the config
/// (for example in a startup log line) never reveals the secret API keys. Only
/// their count is shown.
#[derive(Clone)]
pub struct AuthConfig {
    /// Set of valid API keys. These are secrets; see the `Debug` impl.
    pub api_keys: HashSet<String>,
    /// Header name for API key (default: "Authorization").
    pub header_name: String,
    /// Prefix expected before the key (default: "Bearer ").
    pub prefix: String,
    /// Paths that don't require authentication (matched exactly against the request path).
    pub public_paths: HashSet<String>,
}

impl std::fmt::Debug for AuthConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AuthConfig")
            // Never print the keys themselves; the count is enough for diagnostics.
            .field(
                "api_keys",
                &format_args!("<{} redacted>", self.api_keys.len()),
            )
            .field("header_name", &self.header_name)
            .field("prefix", &self.prefix)
            .field("public_paths", &self.public_paths)
            .finish()
    }
}
90
91impl Default for AuthConfig {
92    fn default() -> Self {
93        Self {
94            api_keys: HashSet::new(),
95            header_name: "Authorization".to_string(),
96            prefix: "Bearer ".to_string(),
97            public_paths: ["/health", "/ready"]
98                .iter()
99                .map(|s| s.to_string())
100                .collect(),
101        }
102    }
103}
104
105impl AuthConfig {
106    /// Create a new auth config with the given API keys.
107    pub fn with_keys(keys: impl IntoIterator<Item = String>) -> Self {
108        Self {
109            api_keys: keys.into_iter().collect(),
110            ..Default::default()
111        }
112    }
113
114    /// Add a public path that doesn't require authentication.
115    pub fn add_public_path(mut self, path: impl Into<String>) -> Self {
116        self.public_paths.insert(path.into());
117        self
118    }
119
120    /// Check if authentication is required for a path.
121    pub fn requires_auth(&self, path: &str) -> bool {
122        !self.public_paths.contains(path)
123    }
124
125    /// Validate an API key.
126    pub fn validate_key(&self, key: &str) -> bool {
127        self.api_keys.contains(key)
128    }
129
130    /// Check if authentication is enabled (has any API keys configured).
131    pub fn is_enabled(&self) -> bool {
132        !self.api_keys.is_empty()
133    }
134}
135
136/// State for authentication middleware.
137#[derive(Clone)]
138pub struct AuthState {
139    /// The authentication configuration.
140    pub config: Arc<AuthConfig>,
141}
142
143/// API key authentication middleware.
144pub async fn api_key_auth(
145    State(auth): State<AuthState>,
146    request: Request,
147    next: Next,
148) -> Result<Response, Response> {
149    let path = request.uri().path();
150
151    // Skip auth for public paths
152    if !auth.config.requires_auth(path) {
153        return Ok(next.run(request).await);
154    }
155
156    // Skip auth if not enabled (no keys configured)
157    if !auth.config.is_enabled() {
158        return Ok(next.run(request).await);
159    }
160
161    // Extract API key from header
162    let auth_header = request
163        .headers()
164        .get(&auth.config.header_name)
165        .and_then(|v| v.to_str().ok());
166
167    let api_key = match auth_header {
168        Some(value) if value.starts_with(&auth.config.prefix) => &value[auth.config.prefix.len()..],
169        Some(_) => {
170            return Err((
171                StatusCode::UNAUTHORIZED,
172                [(header::WWW_AUTHENTICATE, "Bearer")],
173                "Invalid authorization header format",
174            )
175                .into_response());
176        }
177        None => {
178            return Err((
179                StatusCode::UNAUTHORIZED,
180                [(header::WWW_AUTHENTICATE, "Bearer")],
181                "Missing authorization header",
182            )
183                .into_response());
184        }
185    };
186
187    // Validate the key
188    if !auth.config.validate_key(api_key) {
189        tracing::warn!(
190            path = %path,
191            "Invalid API key attempt"
192        );
193        return Err((
194            StatusCode::UNAUTHORIZED,
195            [(header::WWW_AUTHENTICATE, "Bearer")],
196            "Invalid API key",
197        )
198            .into_response());
199    }
200
201    Ok(next.run(request).await)
202}
203
204// ============================================================================
205// Rate Limiting Middleware
206// ============================================================================
207
/// How the rate limiter decides which bucket a request belongs to.
#[derive(Clone, Debug, Default)]
pub enum RateLimitKey {
    /// One bucket per client IP address, taken from the peer socket address.
    /// This is the default strategy.
    #[default]
    ByIp,
    /// One bucket per token found in the `Authorization` header, falling back
    /// to the client IP when the header is absent. The token is not verified
    /// by the rate limiter itself.
    ByApiKey,
}
217
218/// Configuration for rate limiting.
219#[derive(Clone, Debug)]
220pub struct RateLimitConfig {
221    /// Maximum requests per window.
222    pub max_requests: u32,
223    /// Time window duration.
224    pub window: Duration,
225    /// How to identify clients for rate limiting.
226    pub key_strategy: RateLimitKey,
227}
228
229impl Default for RateLimitConfig {
230    fn default() -> Self {
231        Self {
232            max_requests: 100,
233            window: Duration::from_secs(60),
234            key_strategy: RateLimitKey::ByIp,
235        }
236    }
237}
238
/// Per-client counter for the current fixed window.
#[derive(Clone)]
struct RateLimitEntry {
    /// Requests counted since `window_start`, including rejected ones.
    count: u32,
    /// Instant at which the current window began.
    window_start: Instant,
}
245
246/// State for rate limiting middleware.
247#[derive(Clone)]
248pub struct RateLimitState {
249    /// The rate limiting configuration.
250    pub config: Arc<RateLimitConfig>,
251    entries: Arc<RwLock<std::collections::HashMap<String, RateLimitEntry>>>,
252}
253
254impl RateLimitState {
255    /// Create a new rate limit state.
256    pub fn new(config: RateLimitConfig) -> Self {
257        Self {
258            config: Arc::new(config),
259            entries: Arc::new(RwLock::new(std::collections::HashMap::new())),
260        }
261    }
262
263    /// Clean up expired entries.
264    pub async fn cleanup(&self) {
265        let now = Instant::now();
266        let window = self.config.window;
267        let mut entries = self.entries.write().await;
268        entries.retain(|_, entry| now.duration_since(entry.window_start) < window);
269    }
270
271    /// Check and increment rate limit for a key.
272    async fn check_and_increment(&self, key: String) -> Result<(u32, u32), (u32, Duration)> {
273        let now = Instant::now();
274        let mut entries = self.entries.write().await;
275
276        let entry = entries.entry(key).or_insert_with(|| RateLimitEntry {
277            count: 0,
278            window_start: now,
279        });
280
281        // Reset if window has passed
282        if now.duration_since(entry.window_start) >= self.config.window {
283            entry.count = 0;
284            entry.window_start = now;
285        }
286
287        entry.count += 1;
288
289        if entry.count > self.config.max_requests {
290            let retry_after = self.config.window - now.duration_since(entry.window_start);
291            Err((entry.count, retry_after))
292        } else {
293            Ok((
294                self.config.max_requests - entry.count,
295                self.config.max_requests,
296            ))
297        }
298    }
299}
300
301/// Rate limiting middleware.
302pub async fn rate_limit(
303    State(state): State<RateLimitState>,
304    ConnectInfo(addr): ConnectInfo<SocketAddr>,
305    request: Request,
306    next: Next,
307) -> Result<Response, Response> {
308    let key = match state.config.key_strategy {
309        RateLimitKey::ByIp => addr.ip().to_string(),
310        RateLimitKey::ByApiKey => {
311            // Extract API key from Authorization header
312            request
313                .headers()
314                .get(header::AUTHORIZATION)
315                .and_then(|v| v.to_str().ok())
316                .map(|s| s.trim_start_matches("Bearer ").to_string())
317                .unwrap_or_else(|| addr.ip().to_string())
318        }
319    };
320
321    match state.check_and_increment(key).await {
322        Ok((remaining, limit)) => {
323            let mut response = next.run(request).await;
324            let headers = response.headers_mut();
325
326            // Add rate limit headers
327            headers.insert(
328                "X-RateLimit-Limit",
329                HeaderValue::from_str(&limit.to_string()).unwrap(),
330            );
331            headers.insert(
332                "X-RateLimit-Remaining",
333                HeaderValue::from_str(&remaining.to_string()).unwrap(),
334            );
335
336            Ok(response)
337        }
338        Err((_, retry_after)) => {
339            let retry_secs = retry_after.as_secs().max(1);
340            Err((
341                StatusCode::TOO_MANY_REQUESTS,
342                [
343                    ("Retry-After", retry_secs.to_string()),
344                    ("X-RateLimit-Limit", state.config.max_requests.to_string()),
345                    ("X-RateLimit-Remaining", "0".to_string()),
346                ],
347                "Rate limit exceeded",
348            )
349                .into_response())
350        }
351    }
352}
353
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_auth_config_default() {
        let config = AuthConfig::default();
        assert!(!config.is_enabled());
        for path in ["/health", "/ready"] {
            assert!(
                config.public_paths.contains(path),
                "default config should list {path} as public"
            );
        }
    }

    #[test]
    fn test_auth_config_with_keys() {
        let config = AuthConfig::with_keys(vec!["key1".to_owned(), "key2".to_owned()]);
        assert!(config.is_enabled());

        let expectations = [("key1", true), ("key2", true), ("key3", false)];
        for (candidate, accepted) in expectations {
            assert_eq!(
                config.validate_key(candidate),
                accepted,
                "validate_key({candidate:?})"
            );
        }
    }

    #[test]
    fn test_auth_config_public_paths() {
        let config = AuthConfig::default().add_public_path("/metrics");

        let expectations = [
            ("/health", false),
            ("/ready", false),
            ("/metrics", false),
            ("/v1/state", true),
        ];
        for (path, needs_auth) in expectations {
            assert_eq!(
                config.requires_auth(path),
                needs_auth,
                "requires_auth({path:?})"
            );
        }
    }

    #[test]
    fn test_rate_limit_config_default() {
        let RateLimitConfig {
            max_requests,
            window,
            ..
        } = RateLimitConfig::default();
        assert_eq!(max_requests, 100);
        assert_eq!(window, Duration::from_secs(60));
    }

    #[tokio::test]
    async fn test_rate_limit_state() {
        let state = RateLimitState::new(RateLimitConfig {
            max_requests: 3,
            window: Duration::from_secs(60),
            key_strategy: RateLimitKey::ByIp,
        });

        // The quota admits exactly three requests for a single key...
        for attempt in 1..=3 {
            let outcome = state.check_and_increment("test".to_string()).await;
            assert!(outcome.is_ok(), "request {attempt} should be allowed");
        }

        // ...the fourth is rejected...
        let fourth = state.check_and_increment("test".to_string()).await;
        assert!(fourth.is_err());

        // ...while a different key keeps its own, untouched budget.
        let other = state.check_and_increment("other".to_string()).await;
        assert!(other.is_ok());
    }
}
410}