1use crate::core::types::AuthConfig;
2use crate::core::session_auth::SessionAuthMiddleware;
3use axum::{
4 extract::Request,
5 http::{HeaderMap, StatusCode},
6 middleware::Next,
7 response::{IntoResponse, Response},
8 Json,
9};
10use base64::prelude::*;
11use serde_json::json;
12use std::collections::HashMap;
13use subtle::ConstantTimeEq;
14
15pub async fn auth_middleware(
16 config: Option<&AuthConfig>,
17 request: Request,
18 next: Next,
19) -> Result<Response, StatusCode> {
20 if let Some(config) = config {
21 if !config.enabled {
22 return Ok(next.run(request).await);
23 }
24
25 if config.r#type == "session" {
26 let session_auth = SessionAuthMiddleware::new(config).await
28 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
29 return session_auth.handle(request, next).await;
30 }
31
32 if let Err(err) = authenticate_request(&request, config) {
33 return Ok(handle_auth_error(config, &err).into_response());
34 }
35 }
36
37 Ok(next.run(request).await)
38}
39
40fn authenticate_request(request: &Request, config: &AuthConfig) -> Result<(), String> {
41 match config.r#type.as_str() {
42 "basic" => authenticate_basic(request, config),
43 "api_key" => authenticate_api_key(request, config),
44 "bearer" => authenticate_bearer(request, config),
45 _ => Err(format!("unsupported auth type: {}", config.r#type)),
46 }
47}
48
49fn authenticate_basic(request: &Request, config: &AuthConfig) -> Result<(), String> {
50 let headers = request.headers();
51 let auth_header = headers
52 .get("authorization")
53 .ok_or("missing Authorization header")?
54 .to_str()
55 .map_err(|_| "invalid Authorization header")?;
56
57 if !auth_header.starts_with("Basic ") {
58 return Err("invalid Authorization header format".to_string());
59 }
60
61 let payload = BASE64_STANDARD
62 .decode(&auth_header[6..])
63 .map_err(|_| "invalid base64 in Authorization header")?;
64
65 let credentials = String::from_utf8(payload)
66 .map_err(|_| "invalid UTF-8 in credentials")?;
67
68 let parts: Vec<&str> = credentials.splitn(2, ':').collect();
69 if parts.len() != 2 {
70 return Err("invalid credential format".to_string());
71 }
72
73 let (username, password) = (parts[0], parts[1]);
74
75 if username.as_bytes().ct_eq(config.username.as_bytes()).unwrap_u8() != 1
76 || password.as_bytes().ct_eq(config.password.as_bytes()).unwrap_u8() != 1
77 {
78 return Err("invalid credentials".to_string());
79 }
80
81 Ok(())
82}
83
84fn authenticate_api_key(request: &Request, config: &AuthConfig) -> Result<(), String> {
85 let header_name = if config.api_key_header.is_empty() {
86 "x-api-key"
87 } else {
88 &config.api_key_header
89 };
90
91 let headers = request.headers();
92 let api_key = headers
93 .get(header_name)
94 .ok_or_else(|| format!("missing {} header", header_name))?
95 .to_str()
96 .map_err(|_| "invalid API key header")?;
97
98 if api_key.as_bytes().ct_eq(config.api_key.as_bytes()).unwrap_u8() != 1 {
99 return Err("invalid API key".to_string());
100 }
101
102 Ok(())
103}
104
105fn authenticate_bearer(request: &Request, config: &AuthConfig) -> Result<(), String> {
106 let headers = request.headers();
107 let auth_header = headers
108 .get("authorization")
109 .ok_or("missing Authorization header")?
110 .to_str()
111 .map_err(|_| "invalid Authorization header")?;
112
113 if !auth_header.starts_with("Bearer ") {
114 return Err("invalid Authorization header format".to_string());
115 }
116
117 let token = auth_header[7..].trim();
118 if token.is_empty() {
119 return Err("missing bearer token".to_string());
120 }
121
122 if token.as_bytes().ct_eq(config.api_key.as_bytes()).unwrap_u8() != 1 {
123 return Err("invalid bearer token".to_string());
124 }
125
126 Ok(())
127}
128
129fn handle_auth_error(config: &AuthConfig, _error: &str) -> impl IntoResponse {
130 let mut headers = HeaderMap::new();
131
132 match config.r#type.as_str() {
133 "basic" => {
134 let realm = if config.realm.is_empty() {
135 "Bytedocs API Documentation"
136 } else {
137 &config.realm
138 };
139 headers.insert(
140 "www-authenticate",
141 format!(r#"Basic realm="{}""#, realm).parse().unwrap(),
142 );
143 }
144 "bearer" => {
145 headers.insert(
146 "www-authenticate",
147 r#"Bearer realm="Bytedocs API Documentation""#.parse().unwrap(),
148 );
149 }
150 _ => {}
151 }
152
153 let mut error_response = HashMap::new();
154 error_response.insert("error", "Authentication required");
155 error_response.insert("message", "Access to this resource requires authentication");
156 error_response.insert("type", &config.r#type);
157
158 let hint = match config.r#type.as_str() {
159 "basic" => "Use HTTP Basic Authentication with username and password",
160 "api_key" => {
161 let header_name = if config.api_key_header.is_empty() {
162 "X-API-Key"
163 } else {
164 &config.api_key_header
165 };
166 return (
167 StatusCode::UNAUTHORIZED,
168 headers,
169 Json(json!({
170 "error": "Authentication required",
171 "message": "Access to this resource requires authentication",
172 "type": config.r#type,
173 "hint": format!("Provide API key in {} header", header_name)
174 })),
175 );
176 }
177 "bearer" => "Use Authorization: Bearer <token> header",
178 _ => "Authentication required",
179 };
180
181 error_response.insert("hint", hint);
182
183 (
184 StatusCode::UNAUTHORIZED,
185 headers,
186 Json(json!(error_response)),
187 )
188}