bytedocs_rs/core/auth.rs

1use crate::core::types::AuthConfig;
2use crate::core::session_auth::SessionAuthMiddleware;
3use axum::{
4    extract::Request,
5    http::{HeaderMap, StatusCode},
6    middleware::Next,
7    response::{IntoResponse, Response},
8    Json,
9};
10use base64::prelude::*;
11use serde_json::json;
12use std::collections::HashMap;
13use subtle::ConstantTimeEq;
14
15pub async fn auth_middleware(
16    config: Option<&AuthConfig>,
17    request: Request,
18    next: Next,
19) -> Result<Response, StatusCode> {
20    if let Some(config) = config {
21        if !config.enabled {
22            return Ok(next.run(request).await);
23        }
24
25        if config.r#type == "session" {
26            // Handle session auth separately
27            let session_auth = SessionAuthMiddleware::new(config).await
28                .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
29            return session_auth.handle(request, next).await;
30        }
31
32        if let Err(err) = authenticate_request(&request, config) {
33            return Ok(handle_auth_error(config, &err).into_response());
34        }
35    }
36
37    Ok(next.run(request).await)
38}
39
40fn authenticate_request(request: &Request, config: &AuthConfig) -> Result<(), String> {
41    match config.r#type.as_str() {
42        "basic" => authenticate_basic(request, config),
43        "api_key" => authenticate_api_key(request, config),
44        "bearer" => authenticate_bearer(request, config),
45        _ => Err(format!("unsupported auth type: {}", config.r#type)),
46    }
47}
48
49fn authenticate_basic(request: &Request, config: &AuthConfig) -> Result<(), String> {
50    let headers = request.headers();
51    let auth_header = headers
52        .get("authorization")
53        .ok_or("missing Authorization header")?
54        .to_str()
55        .map_err(|_| "invalid Authorization header")?;
56
57    if !auth_header.starts_with("Basic ") {
58        return Err("invalid Authorization header format".to_string());
59    }
60
61    let payload = BASE64_STANDARD
62        .decode(&auth_header[6..])
63        .map_err(|_| "invalid base64 in Authorization header")?;
64
65    let credentials = String::from_utf8(payload)
66        .map_err(|_| "invalid UTF-8 in credentials")?;
67
68    let parts: Vec<&str> = credentials.splitn(2, ':').collect();
69    if parts.len() != 2 {
70        return Err("invalid credential format".to_string());
71    }
72
73    let (username, password) = (parts[0], parts[1]);
74
75    if username.as_bytes().ct_eq(config.username.as_bytes()).unwrap_u8() != 1
76        || password.as_bytes().ct_eq(config.password.as_bytes()).unwrap_u8() != 1
77    {
78        return Err("invalid credentials".to_string());
79    }
80
81    Ok(())
82}
83
84fn authenticate_api_key(request: &Request, config: &AuthConfig) -> Result<(), String> {
85    let header_name = if config.api_key_header.is_empty() {
86        "x-api-key"
87    } else {
88        &config.api_key_header
89    };
90
91    let headers = request.headers();
92    let api_key = headers
93        .get(header_name)
94        .ok_or_else(|| format!("missing {} header", header_name))?
95        .to_str()
96        .map_err(|_| "invalid API key header")?;
97
98    if api_key.as_bytes().ct_eq(config.api_key.as_bytes()).unwrap_u8() != 1 {
99        return Err("invalid API key".to_string());
100    }
101
102    Ok(())
103}
104
105fn authenticate_bearer(request: &Request, config: &AuthConfig) -> Result<(), String> {
106    let headers = request.headers();
107    let auth_header = headers
108        .get("authorization")
109        .ok_or("missing Authorization header")?
110        .to_str()
111        .map_err(|_| "invalid Authorization header")?;
112
113    if !auth_header.starts_with("Bearer ") {
114        return Err("invalid Authorization header format".to_string());
115    }
116
117    let token = auth_header[7..].trim();
118    if token.is_empty() {
119        return Err("missing bearer token".to_string());
120    }
121
122    if token.as_bytes().ct_eq(config.api_key.as_bytes()).unwrap_u8() != 1 {
123        return Err("invalid bearer token".to_string());
124    }
125
126    Ok(())
127}
128
129fn handle_auth_error(config: &AuthConfig, _error: &str) -> impl IntoResponse {
130    let mut headers = HeaderMap::new();
131
132    match config.r#type.as_str() {
133        "basic" => {
134            let realm = if config.realm.is_empty() {
135                "Bytedocs API Documentation"
136            } else {
137                &config.realm
138            };
139            headers.insert(
140                "www-authenticate",
141                format!(r#"Basic realm="{}""#, realm).parse().unwrap(),
142            );
143        }
144        "bearer" => {
145            headers.insert(
146                "www-authenticate",
147                r#"Bearer realm="Bytedocs API Documentation""#.parse().unwrap(),
148            );
149        }
150        _ => {}
151    }
152
153    let mut error_response = HashMap::new();
154    error_response.insert("error", "Authentication required");
155    error_response.insert("message", "Access to this resource requires authentication");
156    error_response.insert("type", &config.r#type);
157
158    let hint = match config.r#type.as_str() {
159        "basic" => "Use HTTP Basic Authentication with username and password",
160        "api_key" => {
161            let header_name = if config.api_key_header.is_empty() {
162                "X-API-Key"
163            } else {
164                &config.api_key_header
165            };
166            return (
167                StatusCode::UNAUTHORIZED,
168                headers,
169                Json(json!({
170                    "error": "Authentication required",
171                    "message": "Access to this resource requires authentication",
172                    "type": config.r#type,
173                    "hint": format!("Provide API key in {} header", header_name)
174                })),
175            );
176        }
177        "bearer" => "Use Authorization: Bearer <token> header",
178        _ => "Authentication required",
179    };
180
181    error_response.insert("hint", hint);
182
183    (
184        StatusCode::UNAUTHORIZED,
185        headers,
186        Json(json!(error_response)),
187    )
188}