Skip to main content

crates_docs/server/
auth_middleware.rs

//! Authentication middleware for HTTP requests.
//!
//! Provides API key authentication middleware for the HTTP/SSE transports.
//! With the `api-key` feature enabled this module exports `ApiKeyMiddleware`;
//! without it, only the always-accepting `NoOpMiddleware` is available.
//!
//! # Example
//!
//! ```rust,no_run
//! use crates_docs::server::auth_middleware::ApiKeyMiddleware;
//! use crates_docs::server::auth::ApiKeyConfig;
//!
//! let config = ApiKeyConfig::default();
//! let middleware = ApiKeyMiddleware::new(config);
//! ```
14
15#[cfg(feature = "api-key")]
16use crate::server::auth::ApiKeyConfig;
17
/// Middleware that checks incoming HTTP/SSE requests for an API key.
///
/// All behavior is driven by the wrapped [`ApiKeyConfig`]: whether checking is
/// enabled, which header / query parameter carries the key, and which keys
/// are accepted.
#[cfg(feature = "api-key")]
pub struct ApiKeyMiddleware {
    /// Authentication settings consulted on every request.
    config: ApiKeyConfig,
}
23
#[cfg(feature = "api-key")]
impl ApiKeyMiddleware {
    /// Create a new API Key middleware from the given configuration.
    #[must_use]
    pub fn new(config: ApiKeyConfig) -> Self {
        Self { config }
    }

    /// Validate the API key carried by a request.
    ///
    /// The key is located with the same precedence as [`Self::extract_key`]:
    /// the configured header first (matched case-insensitively, since HTTP
    /// field names are case-insensitive), then the configured query parameter,
    /// but only when `allow_query_param` is set. If a header key is present,
    /// the query parameter is never consulted, even when the header key is
    /// invalid.
    ///
    /// # Arguments
    ///
    /// * `headers` - HTTP request headers
    /// * `query_params` - URL query parameters (optional)
    ///
    /// # Returns
    ///
    /// `true` if authentication is disabled, or if a key was found and
    /// `ApiKeyConfig::is_valid_key` accepts it; `false` otherwise.
    #[must_use]
    pub fn validate_request(
        &self,
        headers: &std::collections::HashMap<String, String>,
        query_params: Option<&std::collections::HashMap<String, String>>,
    ) -> bool {
        if !self.config.enabled {
            return true;
        }

        match self.find_key(headers, query_params) {
            Some(key) => self.config.is_valid_key(key),
            None => false,
        }
    }

    /// Extract the API key from a request without validating it.
    ///
    /// Uses the same lookup order as [`Self::validate_request`]. This does not
    /// consult `enabled`: it returns a key whenever one is present.
    ///
    /// # Arguments
    ///
    /// * `headers` - HTTP request headers
    /// * `query_params` - URL query parameters (optional)
    ///
    /// # Returns
    ///
    /// The raw API key if one was found, otherwise `None`.
    #[must_use]
    pub fn extract_key(
        &self,
        headers: &std::collections::HashMap<String, String>,
        query_params: Option<&std::collections::HashMap<String, String>>,
    ) -> Option<String> {
        self.find_key(headers, query_params).cloned()
    }

    /// Check if authentication is enabled (mirrors `ApiKeyConfig::enabled`).
    #[must_use]
    pub fn is_enabled(&self) -> bool {
        self.config.enabled
    }

    /// Locate the raw key: header first, then the query parameter if allowed.
    fn find_key<'a>(
        &self,
        headers: &'a std::collections::HashMap<String, String>,
        query_params: Option<&'a std::collections::HashMap<String, String>>,
    ) -> Option<&'a String> {
        let name = &self.config.header_name;

        // Exact-name hash lookup first (the common case). Header names are
        // case-insensitive, though, and HTTP stacks often lowercase them, so fall
        // back to a case-insensitive scan before giving up on the header.
        let from_header = headers.get(name).or_else(|| {
            headers
                .iter()
                .find(|(k, _)| k.eq_ignore_ascii_case(name))
                .map(|(_, v)| v)
        });
        if from_header.is_some() {
            return from_header;
        }

        if !self.config.allow_query_param {
            return None;
        }
        // Query-string keys are case-sensitive, so an exact lookup is correct here.
        query_params.and_then(|params| params.get(&self.config.query_param_name))
    }
}
108
/// Authentication error response.
///
/// Carries a human-readable message and, optionally, the value to send in a
/// `WWW-Authenticate` response header.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AuthError {
    /// Error message
    pub message: String,
    /// WWW-Authenticate header value
    pub www_authenticate: Option<String>,
}

impl AuthError {
    /// Create a new authentication error with no `WWW-Authenticate` challenge.
    #[must_use]
    pub fn new(message: impl Into<String>) -> Self {
        Self {
            message: message.into(),
            www_authenticate: None,
        }
    }

    /// Error for a request that did not carry an API key.
    #[must_use]
    pub fn unauthorized() -> Self {
        Self {
            message: "Unauthorized: API key required".to_string(),
            www_authenticate: Some("ApiKey realm=\"crates-docs\"".to_string()),
        }
    }

    /// Error for a request whose API key was rejected.
    ///
    /// Auth-params in a challenge must be comma-separated (RFC 9110 §11.3),
    /// hence `realm="…", error="…"`.
    #[must_use]
    pub fn invalid_key() -> Self {
        Self {
            message: "Unauthorized: Invalid API key".to_string(),
            www_authenticate: Some(
                "ApiKey realm=\"crates-docs\", error=\"invalid_key\"".to_string(),
            ),
        }
    }
}

/// Displays only the error message; the challenge header is not included.
impl std::fmt::Display for AuthError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.message)
    }
}

impl std::error::Error for AuthError {}
156
/// No-op middleware used when the `api-key` feature is disabled.
///
/// Exposes the same request-checking methods as the feature-gated API key
/// middleware, but never rejects a request.
#[cfg(not(feature = "api-key"))]
#[derive(Debug, Clone, Copy, Default)]
pub struct NoOpMiddleware;

#[cfg(not(feature = "api-key"))]
impl NoOpMiddleware {
    /// Always returns `true`: without the `api-key` feature every request is accepted.
    #[must_use]
    pub fn validate_request(
        &self,
        _headers: &std::collections::HashMap<String, String>,
        _query_params: Option<&std::collections::HashMap<String, String>>,
    ) -> bool {
        true
    }

    /// Always returns `None`: no key is extracted when authentication is compiled out.
    #[must_use]
    pub fn extract_key(
        &self,
        _headers: &std::collections::HashMap<String, String>,
        _query_params: Option<&std::collections::HashMap<String, String>>,
    ) -> Option<String> {
        None
    }

    /// Always returns `false`: authentication is compiled out.
    #[must_use]
    pub fn is_enabled(&self) -> bool {
        false
    }
}
179
#[cfg(test)]
#[cfg(feature = "api-key")]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Build an enabled config that accepts exactly one freshly generated key.
    ///
    /// Returns the config together with the key value the tests send in
    /// requests (the config itself stores `generated.hash`).
    fn create_test_config() -> (ApiKeyConfig, String) {
        let generator = ApiKeyConfig::default();
        let generated = generator
            .generate_key()
            .expect("failed to generate API key");

        let config = ApiKeyConfig {
            enabled: true,
            keys: vec![generated.hash],
            ..Default::default()
        };
        (config, generated.key)
    }

    /// A map holding a single `name -> value` entry.
    fn one_entry(name: &str, value: impl Into<String>) -> HashMap<String, String> {
        HashMap::from([(name.to_string(), value.into())])
    }

    #[test]
    fn test_middleware_disabled() {
        let middleware = ApiKeyMiddleware::new(ApiKeyConfig::default());
        let no_headers: HashMap<String, String> = HashMap::new();
        assert!(middleware.validate_request(&no_headers, None));
    }

    #[test]
    fn test_middleware_valid_key_header() {
        let (config, api_key) = create_test_config();
        let middleware = ApiKeyMiddleware::new(config);
        let headers = one_entry("X-API-Key", api_key);
        assert!(middleware.validate_request(&headers, None));
    }

    #[test]
    fn test_middleware_invalid_key_header() {
        let (config, _) = create_test_config();
        let middleware = ApiKeyMiddleware::new(config);
        let headers = one_entry("X-API-Key", "invalid_key");
        assert!(!middleware.validate_request(&headers, None));
    }

    #[test]
    fn test_middleware_missing_key() {
        let (config, _) = create_test_config();
        let middleware = ApiKeyMiddleware::new(config);
        let no_headers: HashMap<String, String> = HashMap::new();
        assert!(!middleware.validate_request(&no_headers, None));
    }

    #[test]
    fn test_middleware_query_param_allowed() {
        let (config, api_key) = create_test_config();
        let middleware = ApiKeyMiddleware::new(ApiKeyConfig {
            allow_query_param: true,
            ..config
        });
        let no_headers: HashMap<String, String> = HashMap::new();
        let query = one_entry("api_key", api_key);
        assert!(middleware.validate_request(&no_headers, Some(&query)));
    }

    #[test]
    fn test_middleware_query_param_not_allowed() {
        let (config, api_key) = create_test_config();
        let middleware = ApiKeyMiddleware::new(config);
        let no_headers: HashMap<String, String> = HashMap::new();
        let query = one_entry("api_key", api_key);
        assert!(!middleware.validate_request(&no_headers, Some(&query)));
    }

    #[test]
    fn test_extract_key() {
        let (config, api_key) = create_test_config();
        let middleware = ApiKeyMiddleware::new(config);
        let headers = one_entry("X-API-Key", api_key.clone());
        assert_eq!(middleware.extract_key(&headers, None), Some(api_key));
    }

    #[test]
    fn test_auth_error() {
        let error = AuthError::unauthorized();
        assert_eq!(error.message, "Unauthorized: API key required");
        assert!(error.www_authenticate.is_some());
    }
}