#![forbid(unsafe_code)]
#![deny(missing_docs)]
use super::AppState;
use axum::extract::State;
use axum::http::{
Request,
StatusCode,
};
use axum::http::header;
use axum::middleware::Next;
use axum::response::Response;
use std::str::FromStr;
use std::sync::Arc;
use tracing::debug;
mod basic_auth;
mod basic_auth_config;
use basic_auth::BasicAuth;
pub use basic_auth_config::BasicAuthConfig;
/// Stand-in bcrypt hash (cost factor 10, per its `$2b$10$` prefix) that
/// `validate_credentials` may verify against when no stored hash is available
/// for the supplied user id, so that a bcrypt check still runs for unknown
/// users. A successful match against it never grants access, because
/// `validate_credentials` also requires the user id to exist.
const FALLBACK_PASSWORD_HASH: &str = "$2b$10$xbVccvFGkGUTkQm5gsSr8uI2byLz2t7pY3wgo9RfQy5rt77l6fyDa";
/// Axum middleware that enforces HTTP Basic authentication.
///
/// When `basic_auth_users` is `None` in the application's
/// [`BasicAuthConfig`], authentication is disabled and the request is passed
/// straight to `next`. Otherwise the `Authorization` header is parsed with
/// [`BasicAuth::from_str`] and the supplied password is checked with bcrypt
/// against the hash stored for the supplied user id.
///
/// A bcrypt verification is performed even when the user id is unknown, so
/// that response latency does not reveal which user ids exist. For that
/// check a hash from the configured user table is borrowed (falling back to
/// a fixed hash only when the table is empty), so its work factor matches a
/// real check.
///
/// # Errors
///
/// Returns [`StatusCode::UNAUTHORIZED`] when the header is missing or cannot
/// be read as a string, when the credentials carry no password, when the
/// user id is not configured, or when the password does not match. Errors
/// from [`BasicAuth::from_str`] are converted into a [`StatusCode`] by `?`
/// and returned as-is.
pub async fn validate_credentials<B>(
    State(state): State<Arc<AppState>>,
    req: Request<B>,
    next: Next<B>,
) -> Result<Response, StatusCode> {
    debug!("Validating credentials");

    // No user table configured: authentication is turned off.
    let users = match &state.basic_auth_config.basic_auth_users {
        Some(users) => users,
        None => return Ok(next.run(req).await),
    };

    // A missing header, or one that is not readable as a string, counts as
    // "no credentials supplied".
    let auth_header = req
        .headers()
        .get(header::AUTHORIZATION)
        .and_then(|value| value.to_str().ok())
        .ok_or(StatusCode::UNAUTHORIZED)?;
    let basic_auth = BasicAuth::from_str(auth_header)?;

    let user_id = basic_auth.user_id();
    let (user_exists, hashed_password) = match users.get(user_id) {
        Some(hashed_password) => (true, hashed_password.as_str()),
        // Unknown user: still run bcrypt so the response time does not reveal
        // whether the user id exists. Borrowing a configured hash keeps the
        // work factor equal to the known-user path (assuming all configured
        // hashes share a cost); the fixed cost-10 fallback is only used when
        // the user table is empty.
        None => (
            false,
            users
                .values()
                .next()
                .map(|hash| hash.as_str())
                .unwrap_or(FALLBACK_PASSWORD_HASH),
        ),
    };

    let password = basic_auth.password().ok_or(StatusCode::UNAUTHORIZED)?;

    // NOTE(review): bcrypt is deliberately CPU-expensive and runs here on the
    // async executor thread; consider `tokio::task::spawn_blocking` if tokio
    // (with the `rt` feature) is a direct dependency — TODO confirm.
    let validated = bcrypt::verify(password, hashed_password).unwrap_or_else(|e| {
        debug!("Couldn't verify password, bcrypt error: {}", e);
        false
    });

    debug!(
        "validation status: validated: {}, exists: {}",
        validated,
        user_exists,
    );

    // Both are required: the supplied password could match the stand-in hash
    // used for an unknown user, and that must never grant access.
    if !(validated && user_exists) {
        return Err(StatusCode::UNAUTHORIZED);
    }

    Ok(next.run(req).await)
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{
        middleware,
        Router,
    };
    use axum::body::Body;
    use axum::http::{
        self,
        Request,
    };
    use axum::routing::get;
    use std::collections::HashMap;
    use tower::ServiceExt;

    /// Builds a router whose single `/` route is guarded by `validate_credentials`.
    fn app(state: Arc<AppState>) -> Router {
        Router::new()
            .route("/", get(|| async { "Test" }))
            .route_layer(middleware::from_fn_with_state(state, validate_credentials))
    }

    /// Configuration with a single user `foo` whose stored hash (bcrypt cost 4)
    /// matches the password `bar`.
    fn get_users_config() -> BasicAuthConfig {
        let users = HashMap::from([(
            "foo".to_string(),
            "$2b$04$nFPE4cwFjOFGUmdp.o2NTuh/blJDaEwikX1qoitVe144TsS2l5whS".to_string(),
        )]);
        BasicAuthConfig {
            basic_auth_users: Some(users),
        }
    }

    /// Wraps `basic_auth_config` in an `AppState` suitable for `app`.
    fn state(basic_auth_config: BasicAuthConfig) -> Arc<AppState> {
        Arc::new(AppState {
            basic_auth_config,
            index_page: "test".into(),
        })
    }

    /// Sends `GET /` through the guarded router, with the given
    /// `Authorization` header value if any, and returns the response status.
    async fn status_for(state: Arc<AppState>, authorization: Option<&str>) -> StatusCode {
        let mut builder = Request::builder().uri("/");
        if let Some(value) = authorization {
            builder = builder.header(http::header::AUTHORIZATION, value);
        }
        let req = builder.body(Body::empty()).unwrap();
        app(state).oneshot(req).await.unwrap().status()
    }

    #[tokio::test]
    async fn validate_credentials_users_no_auth() {
        let status = status_for(state(get_users_config()), None).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED)
    }

    #[tokio::test]
    async fn validate_credentials_no_users_no_auth() {
        let status = status_for(state(BasicAuthConfig::default()), None).await;
        assert_eq!(status, StatusCode::OK)
    }

    #[tokio::test]
    async fn validate_credentials_ok() {
        // "foo:bar"
        let status = status_for(state(get_users_config()), Some("Basic Zm9vOmJhcg==")).await;
        assert_eq!(status, StatusCode::OK)
    }

    #[tokio::test]
    async fn validate_credentials_unauthorized() {
        // "bad:password" — unknown user.
        let status = status_for(state(get_users_config()), Some("Basic YmFkOnBhc3N3b3Jk")).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED)
    }

    #[tokio::test]
    async fn validate_credentials_unauthorized_no_user_id() {
        // "nope:userdoesntexist" — unknown user.
        let status = status_for(
            state(get_users_config()),
            Some("Basic bm9wZTp1c2VyZG9lc250ZXhpc3Q="),
        )
        .await;
        assert_eq!(status, StatusCode::UNAUTHORIZED)
    }

    #[tokio::test]
    async fn validate_credentials_unauthorized_wrong_password() {
        // "foo:baz" — known user, wrong password.
        let status = status_for(state(get_users_config()), Some("Basic Zm9vOmJheg==")).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED)
    }

    #[tokio::test]
    async fn validate_credentials_unauthorized_empty_user_table() {
        // An empty (but configured) user table must reject everyone.
        let config = BasicAuthConfig {
            basic_auth_users: Some(HashMap::new()),
        };
        let status = status_for(state(config), Some("Basic Zm9vOmJhcg==")).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED)
    }
}