use crate::{admin::app_state::AppState, shared::HttpResult};
use axum::{
body::Body,
extract::{Request, State},
http::Response,
response::IntoResponse,
};
use base64::Engine;
/// Axum handler that guards the inner DAV handler with HTTP Basic authentication.
///
/// The request's `Authorization` header is checked against the admin password
/// held in the application state. When the check fails, a `401` response with a
/// `WWW-Authenticate` challenge is returned and the request is not forwarded.
/// Otherwise the request is handed to `state.inner_dav_handler` and its
/// response is returned.
pub async fn dav_handler(
    State(state): State<AppState>,
    req: Request<Body>,
) -> HttpResult<impl IntoResponse> {
    if !is_valid_authorization_header(req.headers(), &state.admin_password) {
        return Ok(Response::builder()
            .status(401)
            // RFC 7617 makes the `realm` parameter mandatory in a Basic
            // challenge; a bare `Basic` value is not a valid challenge.
            .header("WWW-Authenticate", "Basic realm=\"admin\"")
            .body(Body::from("Unauthorized"))
            .expect("This response should always be valid"));
    }

    let dav_response = state.inner_dav_handler.handle(req).await;
    Ok(dav_response.into_response())
}
/// Looks up the `Authorization` header and validates it as admin credentials.
///
/// Returns `false` when the header is absent, when its value cannot be viewed
/// as a string (`HeaderValue::to_str` fails), or when the string form is
/// rejected by `is_valid_authorization_header_str`.
fn is_valid_authorization_header(headers: &axum::http::HeaderMap, should_password: &str) -> bool {
    headers
        .get("Authorization")
        .and_then(|raw_value| raw_value.to_str().ok())
        .map(|header_text| is_valid_authorization_header_str(header_text, should_password))
        .unwrap_or(false)
}
/// Validates the value of an HTTP `Authorization` header against the admin
/// credentials.
///
/// The value must have the form `Basic <base64>`, where the payload decodes
/// (standard base64 alphabet) to UTF-8 text of the form `username:password`.
/// Only the first `:` separates the two parts, so the password itself may
/// contain colons.
///
/// Returns `true` only when the username is `admin` and the password equals
/// `should_password`. Every malformed input yields `false`.
fn is_valid_authorization_header_str(auth_header: &str, should_password: &str) -> bool {
    // Authentication scheme names are case-insensitive and are followed by
    // one or more spaces (RFC 9110, section 11.4), so `basic` and extra
    // separator spaces are accepted as well.
    let Some((scheme, rest)) = auth_header.split_once(' ') else {
        return false;
    };
    if !scheme.eq_ignore_ascii_case("Basic") {
        return false;
    }
    let base64_encoded = rest.trim_start_matches(' ');

    let Ok(decoded) = base64::engine::general_purpose::STANDARD.decode(base64_encoded) else {
        return false;
    };
    let Ok(decoded_str) = String::from_utf8(decoded) else {
        return false;
    };
    let Some((username, password)) = decoded_str.split_once(':') else {
        return false;
    };

    // Evaluate both comparisons unconditionally (`&`, not `&&`) so a wrong
    // username does not skip the password check.
    let username_ok = bytes_eq_without_short_circuit(username.as_bytes(), b"admin");
    let password_ok =
        bytes_eq_without_short_circuit(password.as_bytes(), should_password.as_bytes());
    username_ok & password_ok
}

/// Compares two byte slices without stopping at the first differing byte.
///
/// Intended to reduce the timing signal an attacker gets when guessing
/// credentials. Slices of different lengths are rejected immediately, so the
/// length is not hidden. This is a best-effort measure: the optimizer is not
/// prevented from transforming the loop.
fn bytes_eq_without_short_circuit(left: &[u8], right: &[u8]) -> bool {
    if left.len() != right.len() {
        return false;
    }
    left.iter()
        .zip(right)
        .fold(0u8, |diff, (a, b)| diff | (a ^ b))
        == 0
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `is_valid_authorization_header_str` over a table of header values,
    /// always with the expected password `"password"`.
    #[test]
    fn test_is_valid_authorization_header() {
        // Each row: (header value, expected verdict, message shown on failure).
        let cases = [
            // base64("admin:password")
            (
                "Basic YWRtaW46cGFzc3dvcmQ=",
                true,
                "Valid authorization header should be valid",
            ),
            // Wrong scheme name.
            (
                "NotBasic YWRtaW46cGFzc3dvcmQ=",
                false,
                "Invalid format should be invalid",
            ),
            // Scheme without any credentials.
            ("Basic", false, "Invalid format should be invalid"),
            // Payload is not valid base64.
            (
                "Basic invalid-base64",
                false,
                "Invalid base64 should be invalid",
            ),
            // base64("user:password")
            (
                "Basic dXNlcjpwYXNzd29yZA==",
                false,
                "Wrong username should be invalid",
            ),
            // base64("admin:wrong-password")
            (
                "Basic YWRtaW46d3JvbmctcGFzc3dvcmQ=",
                false,
                "Wrong password should be invalid",
            ),
            // base64("admin") - no colon separator.
            (
                "Basic YWRtaW4=",
                false,
                "Malformed credentials should be invalid",
            ),
        ];

        for (header, expected, message) in cases {
            assert!(
                is_valid_authorization_header_str(header, "password") == expected,
                "{}",
                message
            );
        }
    }
}