1use std::sync::Arc;
6
7use axum::extract::Request;
8use axum::http::StatusCode;
9use axum::middleware::Next;
10use axum::response::{IntoResponse, Response};
11
/// Authentication settings for the API.
///
/// `Debug` is implemented by hand (instead of derived) so that the configured
/// API key is never written to logs; the output only shows whether a key is set.
#[derive(Clone)]
pub struct AuthConfig {
    /// Whether requests must present credentials.
    pub enabled: bool,
    /// The API key accepted when `enabled` is `true`.
    api_key: Option<String>,
}

impl std::fmt::Debug for AuthConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AuthConfig")
            .field("enabled", &self.enabled)
            // Redact the secret: only reveal whether a key is configured.
            .field("api_key", &self.api_key.as_ref().map(|_| "<redacted>"))
            .finish()
    }
}
20
21impl AuthConfig {
22 pub const fn disabled() -> Self {
24 Self {
25 enabled: false,
26 api_key: None,
27 }
28 }
29
30 pub const fn with_api_key(api_key: String) -> Self {
32 Self {
33 enabled: true,
34 api_key: Some(api_key),
35 }
36 }
37
38 pub fn validate_key(&self, provided_key: &str) -> bool {
40 if !self.enabled {
41 return true;
42 }
43
44 match &self.api_key {
45 Some(key) => constant_time_compare(key, provided_key),
46 None => false,
47 }
48 }
49
50 pub const fn is_required(&self) -> bool {
52 self.enabled
53 }
54
55 pub fn api_key(&self) -> Option<&str> {
57 self.api_key.as_deref()
58 }
59}
60
61impl Default for AuthConfig {
62 fn default() -> Self {
63 Self::disabled()
64 }
65}
66
/// Failure modes of the credential-extraction and validation helpers.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AuthError {
    /// No credential was supplied.
    MissingCredentials,
    /// A credential was supplied but did not match (displayed as "Invalid API key").
    InvalidCredentials,
    /// An authorization header named a scheme but carried no credential after it.
    MalformedHeader,
}

impl std::fmt::Display for AuthError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let text = match self {
            Self::MissingCredentials => "Authentication required",
            Self::InvalidCredentials => "Invalid API key",
            Self::MalformedHeader => "Malformed authorization header",
        };
        f.write_str(text)
    }
}

impl std::error::Error for AuthError {}

/// Result alias used by the credential-extraction helpers.
pub type AuthResult<T> = Result<T, AuthError>;
92
93pub fn extract_from_header(header_value: &str) -> AuthResult<String> {
100 let header = header_value.trim();
101
102 if header.is_empty() {
103 return Err(AuthError::MissingCredentials);
104 }
105
106 if let Some(rest) = header.strip_prefix("Bearer ") {
108 let key = rest.trim();
109 if key.is_empty() {
110 return Err(AuthError::MalformedHeader);
111 }
112 return Ok(key.to_string());
113 }
114
115 if let Some(rest) = header.strip_prefix("Bearer\t") {
117 let key = rest.trim();
118 if key.is_empty() {
119 return Err(AuthError::MalformedHeader);
120 }
121 return Ok(key.to_string());
122 }
123
124 if header == "Bearer" {
126 return Err(AuthError::MalformedHeader);
127 }
128
129 if let Some(rest) = header.strip_prefix("ApiKey ") {
131 let key = rest.trim();
132 if key.is_empty() {
133 return Err(AuthError::MalformedHeader);
134 }
135 return Ok(key.to_string());
136 }
137
138 if let Some(rest) = header.strip_prefix("ApiKey\t") {
140 let key = rest.trim();
141 if key.is_empty() {
142 return Err(AuthError::MalformedHeader);
143 }
144 return Ok(key.to_string());
145 }
146
147 if header == "ApiKey" {
149 return Err(AuthError::MalformedHeader);
150 }
151
152 Ok(header.to_string())
154}
155
156pub fn extract_from_ws_protocol(header: &str) -> AuthResult<String> {
162 for protocol in header.split(',') {
163 let protocol = protocol.trim();
164 if let Some(key) = protocol.strip_prefix("varpulis-auth.") {
165 if !key.is_empty() {
166 return Ok(key.to_string());
167 }
168 }
169 }
170 Err(AuthError::MissingCredentials)
171}
172
173pub fn extract_from_query(query: &str) -> AuthResult<String> {
177 if query.is_empty() {
178 return Err(AuthError::MissingCredentials);
179 }
180
181 for pair in query.split('&') {
183 let mut parts = pair.splitn(2, '=');
184 let key = parts.next().unwrap_or("");
185 let value = parts.next().unwrap_or("");
186
187 if (key == "api_key" || key == "token") && !value.is_empty() {
188 let decoded = url_decode(value);
190 return Ok(decoded);
191 }
192 }
193
194 Err(AuthError::MissingCredentials)
195}
196
/// Decodes an `application/x-www-form-urlencoded` value.
///
/// `+` becomes a space, and `%XX` (exactly two hex digits, either case) becomes
/// the byte `0xXX`. The decoded bytes are then interpreted as UTF-8, so
/// multi-byte sequences such as `%C3%A9` yield `é`. Byte sequences that are not
/// valid UTF-8 are replaced with U+FFFD. A `%` that is not followed by two hex
/// digits is kept literally.
fn url_decode(s: &str) -> String {
    /// Value of an ASCII hex digit, or `None` for any other byte.
    /// (Unlike `u8::from_str_radix`, this rejects a leading `+` sign.)
    fn hex_value(byte: u8) -> Option<u8> {
        match byte {
            b'0'..=b'9' => Some(byte - b'0'),
            b'a'..=b'f' => Some(byte - b'a' + 10),
            b'A'..=b'F' => Some(byte - b'A' + 10),
            _ => None,
        }
    }

    let bytes = s.as_bytes();
    // Decode into raw bytes first: a percent-escape is a byte, not a char.
    let mut decoded = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        match bytes[i] {
            b'%' => {
                let hi = bytes.get(i + 1).copied().and_then(hex_value);
                let lo = bytes.get(i + 2).copied().and_then(hex_value);
                if let (Some(hi), Some(lo)) = (hi, lo) {
                    decoded.push((hi << 4) | lo);
                    i += 3;
                    continue;
                }
                // Not a valid escape: keep the '%' and reprocess what follows.
                decoded.push(b'%');
            }
            b'+' => decoded.push(b' '),
            other => decoded.push(other),
        }
        i += 1;
    }
    String::from_utf8_lossy(&decoded).into_owned()
}
224
225pub fn constant_time_compare(a: &str, b: &str) -> bool {
230 varpulis_core::security::constant_time_compare(a, b)
231}
232
233pub fn generate_api_key() -> String {
238 use rand::RngExt;
239
240 let mut rng = rand::rng();
241 let mut key = String::with_capacity(32);
242 const CHARSET: &[u8] = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789";
243
244 for _ in 0..32 {
245 let idx = rng.random_range(0..CHARSET.len());
246 key.push(CHARSET[idx] as char);
247 }
248
249 key
250}
251
252pub fn auth_middleware(config: Arc<AuthConfig>) -> impl tower::Layer<axum::routing::Route> + Clone {
263 axum::middleware::from_fn_with_state::<_, _, ()>(config, auth_middleware_fn)
264}
265
266#[derive(Debug, Clone)]
268pub struct AuthState {
269 pub config: Arc<AuthConfig>,
270 pub oauth_state: Option<crate::oauth::SharedOAuthState>,
271}
272
273pub fn auth_middleware_with_jwt(
275 config: Arc<AuthConfig>,
276 oauth_state: Option<crate::oauth::SharedOAuthState>,
277) -> impl tower::Layer<axum::routing::Route> + Clone {
278 let state = AuthState {
279 config,
280 oauth_state,
281 };
282 axum::middleware::from_fn_with_state::<_, _, ()>(state, auth_middleware_jwt_fn)
283}
284
285pub async fn auth_middleware_fn(
287 axum::extract::State(config): axum::extract::State<Arc<AuthConfig>>,
288 req: Request,
289 next: Next,
290) -> Result<Response, AuthRejection> {
291 let state = AuthState {
292 config,
293 oauth_state: None,
294 };
295 check_auth(&state, &req).await?;
296 Ok(next.run(req).await)
297}
298
299async fn auth_middleware_jwt_fn(
301 axum::extract::State(state): axum::extract::State<AuthState>,
302 req: Request,
303 next: Next,
304) -> Result<Response, AuthRejection> {
305 check_auth(&state, &req).await?;
306 Ok(next.run(req).await)
307}
308
309pub async fn check_auth(state: &AuthState, req: &Request) -> Result<(), AuthRejection> {
311 check_auth_from_parts(state, req.headers(), req.uri()).await
312}
313
314pub async fn check_auth_from_parts(
319 state: &AuthState,
320 headers: &axum::http::HeaderMap,
321 uri: &axum::http::Uri,
322) -> Result<(), AuthRejection> {
323 let config = &state.config;
324 let oauth = &state.oauth_state;
325
326 if !config.is_required() {
328 return Ok(());
329 }
330
331 let auth_header = headers
332 .get("authorization")
333 .and_then(|v| v.to_str().ok())
334 .map(|s| s.to_string());
335 let cookie_header = headers
336 .get("cookie")
337 .and_then(|v| v.to_str().ok())
338 .map(|s| s.to_string());
339 let ws_protocol = headers
340 .get("sec-websocket-protocol")
341 .and_then(|v| v.to_str().ok())
342 .map(|s| s.to_string());
343 let query = uri.query().unwrap_or("").to_string();
344
345 if let Some(header) = &auth_header {
347 match extract_from_header(header) {
348 Ok(key) if config.validate_key(&key) => return Ok(()),
349 Ok(_) => return Err(AuthRejection::InvalidCredentials),
350 Err(AuthError::MalformedHeader) => return Err(AuthRejection::MalformedHeader),
351 Err(_) => {} }
353 }
354
355 if let Some(ref cookie) = cookie_header {
357 if let Some(jwt) = crate::oauth::extract_jwt_from_cookie(cookie) {
358 if let Some(ref state) = oauth {
359 let hash = crate::oauth::token_hash(&jwt);
361 if !state.sessions.read().await.is_revoked(&hash)
362 && crate::oauth::verify_jwt(&state.config, &jwt).is_ok()
363 {
364 return Ok(());
365 }
366 }
367 }
368 }
369
370 if let Some(ref header) = auth_header {
372 if let Some(token) = header.strip_prefix("Bearer ") {
373 let token = token.trim();
374 if !token.is_empty() {
375 if let Some(ref state) = oauth {
376 let hash = crate::oauth::token_hash(token);
377 if !state.sessions.read().await.is_revoked(&hash)
378 && crate::oauth::verify_jwt(&state.config, token).is_ok()
379 {
380 return Ok(());
381 }
382 }
383 }
384 }
385 }
386
387 if let Some(ref protocol) = ws_protocol {
389 match extract_from_ws_protocol(protocol) {
390 Ok(key) if config.validate_key(&key) => return Ok(()),
391 Ok(_) => return Err(AuthRejection::InvalidCredentials),
392 Err(_) => {} }
394 }
395
396 match extract_from_query(&query) {
398 Ok(key) if config.validate_key(&key) => Ok(()),
399 Ok(_) => Err(AuthRejection::InvalidCredentials),
400 Err(_) => Err(AuthRejection::MissingCredentials),
401 }
402}
403
404#[derive(Debug)]
406pub enum AuthRejection {
407 MissingCredentials,
408 InvalidCredentials,
409 MalformedHeader,
410}
411
412impl IntoResponse for AuthRejection {
413 fn into_response(self) -> Response {
414 let (code, message) = match self {
415 Self::MissingCredentials => (StatusCode::UNAUTHORIZED, "Authentication required"),
416 Self::InvalidCredentials => (StatusCode::UNAUTHORIZED, "Invalid API key"),
417 Self::MalformedHeader => (StatusCode::BAD_REQUEST, "Malformed authorization header"),
418 };
419 (code, axum::Json(serde_json::json!({ "error": message }))).into_response()
420 }
421}
422
#[cfg(test)]
mod tests {
    use super::*;

    // ---- AuthConfig ----

    #[test]
    fn test_auth_config_disabled() {
        let config = AuthConfig::disabled();
        assert!(!config.enabled);
        assert!(!config.is_required());
    }

    #[test]
    fn test_auth_config_with_api_key() {
        let config = AuthConfig::with_api_key("secret123".to_string());
        assert!(config.enabled);
        assert!(config.is_required());
    }

    #[test]
    fn test_auth_config_validate_key_disabled() {
        let config = AuthConfig::disabled();
        // With auth disabled every key, including the empty one, is accepted.
        assert!(config.validate_key("anything"));
        assert!(config.validate_key(""));
    }

    #[test]
    fn test_auth_config_validate_key_correct() {
        let config = AuthConfig::with_api_key("secret123".to_string());
        assert!(config.validate_key("secret123"));
    }

    #[test]
    fn test_auth_config_validate_key_incorrect() {
        let config = AuthConfig::with_api_key("secret123".to_string());
        assert!(!config.validate_key("wrong"));
        assert!(!config.validate_key(""));
        // Keys sharing a prefix with the real key must still be rejected.
        assert!(!config.validate_key("secret1234"));
        assert!(!config.validate_key("secret12"));
    }

    #[test]
    fn test_auth_config_default() {
        let config = AuthConfig::default();
        assert!(!config.enabled);
    }

    // ---- extract_from_header ----

    #[test]
    fn test_extract_from_header_bearer() {
        let result = extract_from_header("Bearer my-api-key");
        assert_eq!(result, Ok("my-api-key".to_string()));
    }

    #[test]
    fn test_extract_from_header_bearer_with_spaces() {
        let result = extract_from_header(" Bearer my-api-key ");
        assert_eq!(result, Ok("my-api-key".to_string()));
    }

    #[test]
    fn test_extract_from_header_apikey() {
        let result = extract_from_header("ApiKey secret-key");
        assert_eq!(result, Ok("secret-key".to_string()));
    }

    #[test]
    fn test_extract_from_header_raw() {
        let result = extract_from_header("raw-key-without-prefix");
        assert_eq!(result, Ok("raw-key-without-prefix".to_string()));
    }

    #[test]
    fn test_extract_from_header_empty() {
        let result = extract_from_header("");
        assert_eq!(result, Err(AuthError::MissingCredentials));
    }

    #[test]
    fn test_extract_from_header_bearer_empty_key() {
        let result = extract_from_header("Bearer ");
        assert_eq!(result, Err(AuthError::MalformedHeader));
    }

    #[test]
    fn test_extract_from_header_apikey_empty_key() {
        let result = extract_from_header("ApiKey ");
        assert_eq!(result, Err(AuthError::MalformedHeader));
    }

    // ---- extract_from_query ----

    #[test]
    fn test_extract_from_query_api_key() {
        let result = extract_from_query("api_key=my-secret");
        assert_eq!(result, Ok("my-secret".to_string()));
    }

    #[test]
    fn test_extract_from_query_token() {
        let result = extract_from_query("token=my-token");
        assert_eq!(result, Ok("my-token".to_string()));
    }

    #[test]
    fn test_extract_from_query_with_other_params() {
        let result = extract_from_query("foo=bar&api_key=secret&baz=qux");
        assert_eq!(result, Ok("secret".to_string()));
    }

    #[test]
    fn test_extract_from_query_empty() {
        let result = extract_from_query("");
        assert_eq!(result, Err(AuthError::MissingCredentials));
    }

    #[test]
    fn test_extract_from_query_no_key() {
        let result = extract_from_query("foo=bar&baz=qux");
        assert_eq!(result, Err(AuthError::MissingCredentials));
    }

    #[test]
    fn test_extract_from_query_empty_value() {
        let result = extract_from_query("api_key=");
        assert_eq!(result, Err(AuthError::MissingCredentials));
    }

    #[test]
    fn test_extract_from_query_url_encoded() {
        let result = extract_from_query("api_key=key%20with%20spaces");
        assert_eq!(result, Ok("key with spaces".to_string()));
    }

    #[test]
    fn test_extract_from_query_plus_sign() {
        let result = extract_from_query("api_key=key+with+plus");
        assert_eq!(result, Ok("key with plus".to_string()));
    }

    // ---- extract_from_ws_protocol ----

    #[test]
    fn test_extract_from_ws_protocol_valid() {
        let result = extract_from_ws_protocol("varpulis-v1, varpulis-auth.my-secret-key");
        assert_eq!(result, Ok("my-secret-key".to_string()));
    }

    #[test]
    fn test_extract_from_ws_protocol_only_auth() {
        let result = extract_from_ws_protocol("varpulis-auth.abc123");
        assert_eq!(result, Ok("abc123".to_string()));
    }

    #[test]
    fn test_extract_from_ws_protocol_no_auth() {
        let result = extract_from_ws_protocol("varpulis-v1");
        assert!(result.is_err());
    }

    #[test]
    fn test_extract_from_ws_protocol_empty() {
        let result = extract_from_ws_protocol("");
        assert!(result.is_err());
    }

    #[test]
    fn test_extract_from_ws_protocol_empty_key() {
        let result = extract_from_ws_protocol("varpulis-auth.");
        assert!(result.is_err());
    }

    // ---- url_decode ----

    #[test]
    fn test_url_decode_plain() {
        assert_eq!(url_decode("hello"), "hello");
    }

    #[test]
    fn test_url_decode_spaces() {
        assert_eq!(url_decode("hello%20world"), "hello world");
    }

    #[test]
    fn test_url_decode_plus() {
        assert_eq!(url_decode("hello+world"), "hello world");
    }

    #[test]
    fn test_url_decode_special_chars() {
        assert_eq!(url_decode("%21%40%23"), "!@#");
    }

    // ---- constant_time_compare ----

    #[test]
    fn test_constant_time_compare_equal() {
        assert!(constant_time_compare("abc", "abc"));
        assert!(constant_time_compare("", ""));
        assert!(constant_time_compare(
            "longer-string-123",
            "longer-string-123"
        ));
    }

    #[test]
    fn test_constant_time_compare_not_equal() {
        assert!(!constant_time_compare("abc", "abd"));
        assert!(!constant_time_compare("abc", "ab"));
        assert!(!constant_time_compare("abc", "abcd"));
        assert!(!constant_time_compare("", "a"));
    }

    // ---- generate_api_key ----

    #[test]
    fn test_generate_api_key_length() {
        let key = generate_api_key();
        assert_eq!(key.len(), 32);
    }

    #[test]
    fn test_generate_api_key_alphanumeric() {
        let key = generate_api_key();
        assert!(key.chars().all(|c| c.is_ascii_alphanumeric()));
    }

    #[test]
    fn test_generate_api_key_unique() {
        // Keys are drawn from an RNG rather than derived from the clock, so no
        // delay is needed between the two calls.
        let key1 = generate_api_key();
        let key2 = generate_api_key();
        assert_ne!(key1, key2);
    }

    // ---- AuthError Display ----

    #[test]
    fn test_auth_error_display_missing() {
        let err = AuthError::MissingCredentials;
        assert_eq!(format!("{err}"), "Authentication required");
    }

    #[test]
    fn test_auth_error_display_invalid() {
        let err = AuthError::InvalidCredentials;
        assert_eq!(format!("{err}"), "Invalid API key");
    }

    #[test]
    fn test_auth_error_display_malformed() {
        let err = AuthError::MalformedHeader;
        assert_eq!(format!("{err}"), "Malformed authorization header");
    }

    // ---- check_auth (API-key only; no OAuth state) ----

    #[tokio::test]
    async fn test_with_auth_disabled() {
        let config = Arc::new(AuthConfig::disabled());
        let state = AuthState {
            config,
            oauth_state: None,
        };
        let req = Request::builder()
            .uri("/")
            .body(axum::body::Body::empty())
            .unwrap();
        let result = check_auth(&state, &req).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_with_auth_valid_header() {
        let config = Arc::new(AuthConfig::with_api_key("secret".to_string()));
        let state = AuthState {
            config,
            oauth_state: None,
        };
        let req = Request::builder()
            .uri("/")
            .header("authorization", "Bearer secret")
            .body(axum::body::Body::empty())
            .unwrap();
        let result = check_auth(&state, &req).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_with_auth_valid_query() {
        let config = Arc::new(AuthConfig::with_api_key("secret".to_string()));
        let state = AuthState {
            config,
            oauth_state: None,
        };
        let req = Request::builder()
            .uri("/?api_key=secret")
            .body(axum::body::Body::empty())
            .unwrap();
        let result = check_auth(&state, &req).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_with_auth_invalid_key() {
        let config = Arc::new(AuthConfig::with_api_key("secret".to_string()));
        let state = AuthState {
            config,
            oauth_state: None,
        };
        let req = Request::builder()
            .uri("/")
            .header("authorization", "Bearer wrong")
            .body(axum::body::Body::empty())
            .unwrap();
        let result = check_auth(&state, &req).await;
        assert!(matches!(result, Err(AuthRejection::InvalidCredentials)));
    }

    #[tokio::test]
    async fn test_with_auth_missing_credentials() {
        let config = Arc::new(AuthConfig::with_api_key("secret".to_string()));
        let state = AuthState {
            config,
            oauth_state: None,
        };
        let req = Request::builder()
            .uri("/")
            .body(axum::body::Body::empty())
            .unwrap();
        let result = check_auth(&state, &req).await;
        assert!(matches!(result, Err(AuthRejection::MissingCredentials)));
    }
}