crates_docs/server/
auth_middleware.rs1#[cfg(feature = "api-key")]
16use crate::server::auth::ApiKeyConfig;
17
#[cfg(feature = "api-key")]
/// Request-level API key authentication driven by an [`ApiKeyConfig`].
///
/// The key is looked up first in the configured header and, only when the
/// config's `allow_query_param` flag is set, in the configured query parameter.
pub struct ApiKeyMiddleware {
    config: ApiKeyConfig,
}

#[cfg(feature = "api-key")]
impl ApiKeyMiddleware {
    /// Creates a middleware that authenticates requests against `config`.
    #[must_use]
    pub fn new(config: ApiKeyConfig) -> Self {
        Self { config }
    }

    /// Returns `true` when the request should be allowed through.
    ///
    /// - If the config is disabled, every request is allowed.
    /// - Otherwise the key is located with the same rules as
    ///   [`extract_key`](Self::extract_key) and checked with
    ///   `ApiKeyConfig::is_valid_key`.
    /// - A request that carries no key is rejected.
    ///
    /// A key found in the header takes precedence. If it is present but invalid,
    /// the query parameter is not consulted.
    #[must_use]
    pub fn validate_request(
        &self,
        headers: &std::collections::HashMap<String, String>,
        query_params: Option<&std::collections::HashMap<String, String>>,
    ) -> bool {
        if !self.config.enabled {
            return true;
        }

        match self.find_key(headers, query_params) {
            Some(key) => self.config.is_valid_key(key),
            None => false,
        }
    }

    /// Returns the raw (unvalidated) API key carried by the request, if any.
    ///
    /// - The configured header is checked first. Header names are matched
    ///   case-insensitively, as HTTP requires.
    /// - If the header is absent, the configured query parameter is checked,
    ///   but only when query-parameter keys are allowed by the config.
    ///
    /// Unlike [`validate_request`](Self::validate_request), this does not look at
    /// the config's `enabled` flag.
    #[must_use]
    pub fn extract_key(
        &self,
        headers: &std::collections::HashMap<String, String>,
        query_params: Option<&std::collections::HashMap<String, String>>,
    ) -> Option<String> {
        self.find_key(headers, query_params).cloned()
    }

    /// Returns whether API key authentication is enabled in the config.
    #[must_use]
    pub fn is_enabled(&self) -> bool {
        self.config.enabled
    }

    /// Shared lookup: header first, then (if allowed) the query parameter.
    fn find_key<'a>(
        &self,
        headers: &'a std::collections::HashMap<String, String>,
        query_params: Option<&'a std::collections::HashMap<String, String>>,
    ) -> Option<&'a String> {
        if let Some(key) = Self::header_value(headers, &self.config.header_name) {
            return Some(key);
        }

        if !self.config.allow_query_param {
            return None;
        }

        // Query parameter names are case-sensitive, so an exact lookup is correct here.
        query_params.and_then(|params| params.get(&self.config.query_param_name))
    }

    /// Looks up header `name`: exact match first, then an ASCII case-insensitive match.
    ///
    /// HTTP field names are case-insensitive, and many HTTP stacks normalise them
    /// to lowercase, so a purely exact lookup would miss e.g. `x-api-key`. If
    /// several differently-cased entries match and none matches exactly, which one
    /// is returned is unspecified.
    fn header_value<'a>(
        headers: &'a std::collections::HashMap<String, String>,
        name: &str,
    ) -> Option<&'a String> {
        headers.get(name).or_else(|| {
            headers
                .iter()
                .find(|(candidate, _)| candidate.eq_ignore_ascii_case(name))
                .map(|(_, value)| value)
        })
    }
}
108
/// Authentication failure to be reported to a client.
///
/// `message` is the human-readable reason and is also what [`Display`](std::fmt::Display)
/// prints. `www_authenticate`, when set, holds a challenge string intended for the
/// `WWW-Authenticate` response header.
#[derive(Debug, Clone)]
pub struct AuthError {
    /// Human-readable error message.
    pub message: String,
    /// Optional `WWW-Authenticate` challenge to send along with the error.
    pub www_authenticate: Option<String>,
}

impl AuthError {
    /// Auth scheme plus realm shared by every challenge produced by this type.
    const CHALLENGE: &'static str = "ApiKey realm=\"crates-docs\"";

    /// Creates an error with the given message and no `WWW-Authenticate` challenge.
    #[must_use]
    pub fn new(message: impl Into<String>) -> Self {
        Self {
            message: message.into(),
            www_authenticate: None,
        }
    }

    /// Error for a request that supplied no API key.
    ///
    /// Carries the bare `ApiKey realm="crates-docs"` challenge.
    #[must_use]
    pub fn unauthorized() -> Self {
        Self {
            message: "Unauthorized: API key required".to_string(),
            www_authenticate: Some(Self::CHALLENGE.to_string()),
        }
    }

    /// Error for a request whose API key was rejected.
    ///
    /// The challenge adds `error="invalid_key"`. Auth-params in a challenge form a
    /// comma-separated list, so a comma (not just a space) must follow the realm.
    #[must_use]
    pub fn invalid_key() -> Self {
        Self {
            message: "Unauthorized: Invalid API key".to_string(),
            www_authenticate: Some(format!("{}, error=\"invalid_key\"", Self::CHALLENGE)),
        }
    }
}

impl std::fmt::Display for AuthError {
    /// Writes only the human-readable message.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.message)
    }
}

impl std::error::Error for AuthError {}
156
#[cfg(not(feature = "api-key"))]
/// Stand-in middleware used when the `api-key` feature is disabled.
///
/// It offers the same request-facing methods as the feature-gated
/// `ApiKeyMiddleware`, but authentication is always off. Every request is
/// accepted and no key is ever extracted.
#[derive(Debug, Clone, Copy, Default)]
pub struct NoOpMiddleware;

#[cfg(not(feature = "api-key"))]
impl NoOpMiddleware {
    /// Always returns `true`: with the `api-key` feature off, nothing is authenticated.
    #[must_use]
    pub fn validate_request(
        &self,
        _headers: &std::collections::HashMap<String, String>,
        _query_params: Option<&std::collections::HashMap<String, String>>,
    ) -> bool {
        true
    }

    /// Always returns `None`.
    ///
    /// Provided so callers can use `extract_key` whether or not the `api-key`
    /// feature is enabled.
    #[must_use]
    pub fn extract_key(
        &self,
        _headers: &std::collections::HashMap<String, String>,
        _query_params: Option<&std::collections::HashMap<String, String>>,
    ) -> Option<String> {
        None
    }

    /// Always returns `false`: authentication is compiled out.
    #[must_use]
    pub fn is_enabled(&self) -> bool {
        false
    }
}
179
#[cfg(test)]
#[cfg(feature = "api-key")]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds an enabled config holding exactly one freshly generated key.
    /// Returns the config together with that key's plaintext.
    fn create_test_config() -> (ApiKeyConfig, String) {
        let generator = ApiKeyConfig::default();
        let generated = generator
            .generate_key()
            .expect("failed to generate API key");

        (
            ApiKeyConfig {
                enabled: true,
                keys: vec![generated.hash],
                ..Default::default()
            },
            generated.key,
        )
    }

    /// Builds a one-entry map, used for both headers and query parameters.
    fn single_entry(name: &str, value: &str) -> HashMap<String, String> {
        HashMap::from([(name.to_string(), value.to_string())])
    }

    #[test]
    fn test_middleware_disabled() {
        let middleware = ApiKeyMiddleware::new(ApiKeyConfig::default());

        assert!(!middleware.is_enabled());
        assert!(middleware.validate_request(&HashMap::new(), None));
    }

    #[test]
    fn test_middleware_disabled_ignores_invalid_key() {
        let config = ApiKeyConfig::default();
        let headers = single_entry(&config.header_name, "invalid_key");
        let middleware = ApiKeyMiddleware::new(config);

        assert!(middleware.validate_request(&headers, None));
    }

    #[test]
    fn test_middleware_valid_key_header() {
        let (config, api_key) = create_test_config();
        let headers = single_entry(&config.header_name, &api_key);
        let middleware = ApiKeyMiddleware::new(config);

        assert!(middleware.is_enabled());
        assert!(middleware.validate_request(&headers, None));
    }

    #[test]
    fn test_middleware_invalid_key_header() {
        let (config, _) = create_test_config();
        let headers = single_entry(&config.header_name, "invalid_key");
        let middleware = ApiKeyMiddleware::new(config);

        assert!(!middleware.validate_request(&headers, None));
    }

    #[test]
    fn test_middleware_missing_key() {
        let (config, _) = create_test_config();
        let middleware = ApiKeyMiddleware::new(config);

        assert!(!middleware.validate_request(&HashMap::new(), None));
    }

    #[test]
    fn test_middleware_query_param_allowed() {
        let (mut config, api_key) = create_test_config();
        config.allow_query_param = true;
        let query_params = single_entry(&config.query_param_name, &api_key);
        let middleware = ApiKeyMiddleware::new(config);

        assert!(middleware.validate_request(&HashMap::new(), Some(&query_params)));
    }

    #[test]
    fn test_middleware_query_param_not_allowed() {
        let (config, api_key) = create_test_config();
        let query_params = single_entry(&config.query_param_name, &api_key);
        let middleware = ApiKeyMiddleware::new(config);

        assert!(!middleware.validate_request(&HashMap::new(), Some(&query_params)));
    }

    #[test]
    fn test_header_takes_precedence_over_query_param() {
        // An invalid header key must not fall back to a valid query-parameter key.
        let (mut config, api_key) = create_test_config();
        config.allow_query_param = true;
        let headers = single_entry(&config.header_name, "invalid_key");
        let query_params = single_entry(&config.query_param_name, &api_key);
        let middleware = ApiKeyMiddleware::new(config);

        assert!(!middleware.validate_request(&headers, Some(&query_params)));
    }

    #[test]
    fn test_extract_key() {
        let (config, api_key) = create_test_config();
        let headers = single_entry(&config.header_name, &api_key);
        let middleware = ApiKeyMiddleware::new(config);

        assert_eq!(middleware.extract_key(&headers, None), Some(api_key));
    }

    #[test]
    fn test_extract_key_from_query_param() {
        let (mut config, api_key) = create_test_config();
        config.allow_query_param = true;
        let query_params = single_entry(&config.query_param_name, &api_key);
        let middleware = ApiKeyMiddleware::new(config);

        assert_eq!(
            middleware.extract_key(&HashMap::new(), Some(&query_params)),
            Some(api_key)
        );
    }

    #[test]
    fn test_extract_key_query_param_not_allowed() {
        let (config, api_key) = create_test_config();
        let query_params = single_entry(&config.query_param_name, &api_key);
        let middleware = ApiKeyMiddleware::new(config);

        assert_eq!(middleware.extract_key(&HashMap::new(), Some(&query_params)), None);
    }

    #[test]
    fn test_extract_key_missing() {
        let (config, _) = create_test_config();
        let middleware = ApiKeyMiddleware::new(config);

        assert_eq!(middleware.extract_key(&HashMap::new(), None), None);
    }

    #[test]
    fn test_auth_error() {
        let error = AuthError::unauthorized();
        assert_eq!(error.message, "Unauthorized: API key required");
        assert!(error.www_authenticate.is_some());
    }

    #[test]
    fn test_auth_error_invalid_key() {
        let error = AuthError::invalid_key();
        assert_eq!(error.message, "Unauthorized: Invalid API key");
        let challenge = error
            .www_authenticate
            .expect("invalid_key() must carry a challenge");
        assert!(challenge.starts_with("ApiKey realm=\"crates-docs\""));
        assert!(challenge.contains("error=\"invalid_key\""));
    }

    #[test]
    fn test_auth_error_new_and_display() {
        let error = AuthError::new("boom");
        assert!(error.www_authenticate.is_none());
        assert_eq!(error.to_string(), "boom");
    }
}