use crate::types::{derive_account_id, JmapError, JmapErrorType, Principal};
use axum::{
extract::{Request, State},
http::{header, HeaderMap, StatusCode},
middleware::Next,
response::{IntoResponse, Response},
Json,
};
use base64::{engine::general_purpose, Engine as _};
use rusmes_auth::AuthBackend;
use rusmes_proto::Username;
use std::sync::Arc;
/// Reference-counted, dynamically dispatched [`AuthBackend`]; this is the
/// state type that [`require_auth`] extracts.
pub type SharedAuth = Arc<dyn AuthBackend>;
/// Credentials parsed out of an HTTP `Authorization` header by
/// [`extract_credentials`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Credentials {
/// `Basic` scheme: the username and password obtained by base64-decoding the
/// payload and splitting it at the first `:`.
Basic { username: String, password: String },
/// `Bearer` scheme: the token text with surrounding whitespace removed.
Bearer { token: String },
}
/// Parses the `Authorization` header of `headers` into [`Credentials`].
///
/// Returns `None` when:
/// - the header is absent or cannot be read as a string,
/// - the scheme is neither `Basic` nor `Bearer`,
/// - a `Basic` payload is not valid base64 / UTF-8, has no `:` separator, or
///   has an empty username,
/// - a `Bearer` token is empty.
///
/// An empty `Basic` password is accepted.
pub fn extract_credentials(headers: &HeaderMap) -> Option<Credentials> {
    let raw = headers.get(header::AUTHORIZATION)?;
    let authorization = raw.to_str().ok()?.trim();

    if let Some(encoded) = strip_scheme(authorization, "Basic") {
        let bytes = general_purpose::STANDARD.decode(encoded).ok()?;
        let user_pass = String::from_utf8(bytes).ok()?;
        // Only the first colon separates the two fields; any later colon is
        // part of the password.
        let (username, password) = user_pass.split_once(':')?;
        if username.is_empty() {
            return None;
        }
        return Some(Credentials::Basic {
            username: username.to_string(),
            password: password.to_string(),
        });
    }

    let token = strip_scheme(authorization, "Bearer")?.trim();
    if token.is_empty() {
        None
    } else {
        Some(Credentials::Bearer {
            token: token.to_string(),
        })
    }
}
/// Removes a leading authentication `scheme` (matched ASCII
/// case-insensitively) from `header_value` and returns the credentials text
/// that follows it.
///
/// Returns `None` when:
/// - `header_value` does not begin with `scheme`,
/// - the scheme is not followed by at least one space or tab (so `"Basicabc"`
///   is *not* treated as the `Basic` scheme),
/// - nothing but whitespace follows the scheme.
///
/// Never panics, including on non-ASCII input.
fn strip_scheme<'a>(header_value: &'a str, scheme: &str) -> Option<&'a str> {
    let scheme_len = scheme.len();

    // `get` yields `None` (instead of panicking like `split_at`) when the value
    // is too short or the cut would fall inside a multi-byte character.
    let prefix = header_value.get(..scheme_len)?;
    if !prefix.eq_ignore_ascii_case(scheme) {
        return None;
    }
    let rest = header_value.get(scheme_len..)?;

    // The scheme token must end here: require a separator so that a longer
    // token that merely starts with `scheme` is rejected.
    if !rest.starts_with(|c: char| c == ' ' || c == '\t') {
        return None;
    }

    let credentials = rest.trim_start();
    if credentials.is_empty() {
        None
    } else {
        Some(credentials)
    }
}
/// Verifies `creds` against `auth` and builds the resulting [`Principal`].
///
/// * `Basic` — the username must be accepted by [`Username::new`] and the
///   backend must confirm the password.
/// * `Bearer` — the backend's `verify_bearer_token` must yield a username.
///
/// The returned principal carries the username, the account id produced by
/// `derive_account_id`, and an empty scope list.
///
/// # Errors
///
/// * [`AuthError::Unauthorized`] — invalid username, rejected password, or any
///   failure while verifying a bearer token.
/// * [`AuthError::Backend`] — the backend returned an error while checking a
///   `Basic` password; the payload is that error's string form.
pub async fn authenticate(
    auth: &dyn AuthBackend,
    creds: &Credentials,
) -> Result<Principal, AuthError> {
    // Resolve the authenticated user name first; both schemes then share the
    // same `Principal` construction below.
    let name: String = match creds {
        Credentials::Basic { username, password } => {
            let user = match Username::new(username.clone()) {
                Ok(valid) => valid,
                Err(_) => return Err(AuthError::Unauthorized),
            };
            match auth.authenticate(&user, password).await {
                Ok(true) => username.clone(),
                Ok(false) => return Err(AuthError::Unauthorized),
                Err(err) => return Err(AuthError::Backend(err.to_string())),
            }
        }
        Credentials::Bearer { token } => match auth.verify_bearer_token(token).await {
            Ok(owner) => owner.to_string(),
            // Every bearer verification failure is reported as a plain 401.
            Err(_) => return Err(AuthError::Unauthorized),
        },
    };

    Ok(Principal {
        account_id: derive_account_id(&name),
        username: name,
        scopes: Vec::new(),
    })
}
/// Ways in which request authentication can fail.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AuthError {
/// Credentials were missing, malformed, or not accepted.
Unauthorized,
/// The auth backend returned an error while checking a password; holds the
/// string form of that error.
Backend(String),
}
impl AuthError {
fn into_response_body(self) -> Response {
let detail = match self {
AuthError::Unauthorized => "Authentication required".to_string(),
AuthError::Backend(err) => {
tracing::warn!("JMAP auth backend error: {}", err);
"Authentication backend error".to_string()
}
};
let body = JmapError::new(JmapErrorType::ServerFail)
.with_status(401)
.with_detail(detail);
let mut resp = (StatusCode::UNAUTHORIZED, Json(body)).into_response();
if let Ok(value) = header::HeaderValue::from_str("Basic realm=\"jmap\"") {
resp.headers_mut().insert(header::WWW_AUTHENTICATE, value);
}
resp
}
}
/// Axum middleware that only lets authenticated requests through.
///
/// Credentials are read from the request's `Authorization` header and checked
/// with [`authenticate`]. On success the resulting [`Principal`] is stored in
/// the request extensions and the request is handed to `next`. On any failure
/// (including a missing or unparseable header) the matching [`AuthError`]
/// response is returned and `next` is not invoked.
pub async fn require_auth(
    State(auth): State<SharedAuth>,
    mut request: Request,
    next: Next,
) -> Response {
    let credentials = extract_credentials(request.headers());

    // Missing credentials and rejected credentials funnel into one result.
    let outcome = match &credentials {
        Some(found) => authenticate(auth.as_ref(), found).await,
        None => Err(AuthError::Unauthorized),
    };

    match outcome {
        Ok(principal) => {
            request.extensions_mut().insert(principal);
            next.run(request).await
        }
        Err(failure) => failure.into_response_body(),
    }
}
// Unit tests for `Authorization` header parsing (`extract_credentials`) and
// credential verification (`authenticate`).
#[cfg(test)]
mod tests {
use super::*;
use async_trait::async_trait;
use axum::http::HeaderValue;
// Minimal backend: only the pair alice / hunter2 authenticates; the remaining
// methods implemented here succeed without doing anything.
struct TestBackend;
#[async_trait]
impl AuthBackend for TestBackend {
async fn authenticate(&self, username: &Username, password: &str) -> anyhow::Result<bool> {
Ok(username.as_str() == "alice" && password == "hunter2")
}
async fn verify_identity(&self, _username: &Username) -> anyhow::Result<bool> {
Ok(true)
}
async fn list_users(&self) -> anyhow::Result<Vec<Username>> {
Ok(vec![])
}
async fn create_user(&self, _u: &Username, _p: &str) -> anyhow::Result<()> {
Ok(())
}
async fn delete_user(&self, _u: &Username) -> anyhow::Result<()> {
Ok(())
}
async fn change_password(&self, _u: &Username, _p: &str) -> anyhow::Result<()> {
Ok(())
}
}
// Builds a header map whose `Authorization` header is `value`. If `value` is
// not a valid header value the map is returned empty.
fn header_with_auth(value: &str) -> HeaderMap {
let mut headers = HeaderMap::new();
if let Ok(v) = HeaderValue::from_str(value) {
headers.insert(header::AUTHORIZATION, v);
}
headers
}
// The payload is base64 for "alice:hunter2".
#[test]
fn test_extract_basic_ok() {
let headers = header_with_auth("Basic YWxpY2U6aHVudGVyMg==");
let creds = extract_credentials(&headers).expect("creds parse");
assert_eq!(
creds,
Credentials::Basic {
username: "alice".to_string(),
password: "hunter2".to_string()
}
);
}
// A lowercase scheme name must be accepted as well.
#[test]
fn test_extract_basic_case_insensitive_scheme() {
let headers = header_with_auth("basic YWxpY2U6aHVudGVyMg==");
assert!(extract_credentials(&headers).is_some());
}
// Bearer tokens are passed through verbatim.
#[test]
fn test_extract_bearer_ok() {
let headers = header_with_auth("Bearer abc.def.ghi");
let creds = extract_credentials(&headers).expect("creds parse");
assert_eq!(
creds,
Credentials::Bearer {
token: "abc.def.ghi".to_string()
}
);
}
// No `Authorization` header at all yields no credentials.
#[test]
fn test_extract_no_header() {
let headers = HeaderMap::new();
assert!(extract_credentials(&headers).is_none());
}
// Schemes other than Basic and Bearer are ignored.
#[test]
fn test_extract_unknown_scheme() {
let headers = header_with_auth("Digest something");
assert!(extract_credentials(&headers).is_none());
}
// The payload is base64 for ":pwd", i.e. an empty username.
#[test]
fn test_extract_basic_empty_username_rejected() {
let headers = header_with_auth("Basic OnB3ZA==");
assert!(extract_credentials(&headers).is_none());
}
// The payload is base64 for "alicehunter2", which has no ':' separator.
#[test]
fn test_extract_basic_no_colon_rejected() {
let headers = header_with_auth("Basic YWxpY2VodW50ZXIy");
assert!(extract_credentials(&headers).is_none());
}
// Valid Basic credentials produce a principal with the expected account id.
#[tokio::test]
async fn test_authenticate_basic_ok() {
let backend = TestBackend;
let creds = Credentials::Basic {
username: "alice".to_string(),
password: "hunter2".to_string(),
};
let principal = authenticate(&backend, &creds).await.expect("auth ok");
assert_eq!(principal.username, "alice");
assert_eq!(principal.account_id, "account-alice");
}
// A wrong password is reported as `Unauthorized`, not as a backend error.
#[tokio::test]
async fn test_authenticate_basic_bad_password() {
let backend = TestBackend;
let creds = Credentials::Basic {
username: "alice".to_string(),
password: "wrong".to_string(),
};
let err = authenticate(&backend, &creds)
.await
.expect_err("should fail");
assert_eq!(err, AuthError::Unauthorized);
}
// `TestBackend` does not define `verify_bearer_token`, so the version the
// trait itself provides is used; the test expects that path to end in
// `Unauthorized`.
#[tokio::test]
async fn test_authenticate_bearer_backend_without_override_rejected() {
let backend = TestBackend;
let creds = Credentials::Bearer {
token: "anything".to_string(),
};
let err = authenticate(&backend, &creds)
.await
.expect_err("bearer 401");
assert_eq!(err, AuthError::Unauthorized);
}
// An email-style username that `TestBackend` does not know is rejected with
// `Unauthorized`.
#[tokio::test]
async fn test_authenticate_basic_with_email_username() {
let backend = TestBackend;
let creds = Credentials::Basic {
username: "bob@example.com".to_string(),
password: "hunter2".to_string(),
};
let err = authenticate(&backend, &creds).await.expect_err("rejected");
assert_eq!(err, AuthError::Unauthorized);
}
}