use axum::{
body::{Body, Bytes},
extract::{Request, State},
middleware::Next,
response::{IntoResponse, Response},
};
use serde_json::{Value, json};
use std::{collections::HashMap, sync::Arc};
use url::form_urlencoded;
use crate::{
crypto::core::Crypto,
middlewares::{
ip::get_request_host,
models::{AUTHORIZATION, AuthModel, BASIC, BEARER, CACHE_AUTH_TOKEN, MiddlewareConfig},
},
response::error::{AppError, AppResult},
};
pub async fn interceptor(
config: State<Arc<MiddlewareConfig>>,
mut request: Request,
next: Next,
) -> Response {
let token_store = &config.token_store;
let ignore_urls = &config.ignore_urls;
let prefix = "";
let pms_ignore_urls = &config.pms_ignore_urls;
let auth_basics = &config.auth_basics;
let (request_ip, uri) = get_request_host(&mut request);
tracing::info!(
"Middleware interceptor - client_ip: {} uri: {:?}",
request_ip,
uri
);
if let Some(ignore_url) = ignore_urls
.iter()
.find(|ignore_url| uri.starts_with(ignore_url.as_str()))
{
tracing::info!("Middleware Authorization Ignore Urls :{}", ignore_url);
return next.run(request).await;
}
if let Some(ignore_url) = pms_ignore_urls
.iter()
.find(|ignore_url| uri.starts_with(ignore_url.as_str()))
{
let auth_str = request
.headers()
.get(AUTHORIZATION)
.and_then(|value| value.to_str().ok())
.filter(|auth_str| auth_str.starts_with(BASIC))
.map(|auth_str| auth_str[(BASIC.len() + 1)..].trim());
tracing::info!(
"Middleware Authorization PMS Ignore Urls: {:?} auth_basics:{:?} auth_str:{:?}",
ignore_url,
auth_basics,
auth_str
);
if let Some(auth_str) = auth_str {
if let Some(matched_basic) = auth_basics.iter().find(|basic| basic.as_str() == auth_str)
{
let basic = Crypto::decode_basic_auth_key(matched_basic).map_err(|e| {
tracing::warn!(
"Middleware Authorization BASIC failed: auth_str:{:?} error{:?}",
auth_str,
e
);
AppError::Unauthorized.into_response()
});
tracing::info!(
"Middleware Authorization BASIC Success auth_str:{} basic:{:?}",
auth_str,
basic
);
} else {
tracing::warn!(
"Middleware Authorization BASIC not allowed auth_str:{:?}",
auth_str
);
return AppError::Unauthorized.into_response();
}
} else {
tracing::warn!("Middleware Missing or Invalid Authorization BASIC header");
return AppError::Unauthorized.into_response();
}
return next.run(request).await;
}
let mut token_opt: Option<String> = None;
if let Some(auth_header) = request.headers().get(AUTHORIZATION) {
if let Ok(auth_str) = auth_header.to_str() {
if auth_str.starts_with(BEARER) {
token_opt = Some(auth_str[(BEARER.len() + 1)..].to_string());
}
}
}
if token_opt.is_none() {
if let Some(query) = request.uri().query() {
let params: HashMap<_, _> = form_urlencoded::parse(query.as_bytes())
.into_owned()
.collect();
if let Some(token) = params.get("accessToken") {
token_opt = Some(token.clone());
}
}
}
if let Some(token) = token_opt {
let store_key = format!("{}{}{}", prefix, CACHE_AUTH_TOKEN, token);
let auth_model: AuthModel = match crate::middlewares::token_store::store_get::<AuthModel>(
token_store.as_ref(),
&store_key,
)
.await
{
Ok(Some(m)) => m,
Ok(None) => {
return {
tracing::warn!(
"Middleware token expired: store_key:{} token:{}",
store_key,
token
);
AppError::TokenExpired.into_response()
};
}
Err(e) => {
tracing::warn!("Middleware failed to fetch token from store: {}", e);
return AppError::TokenExpired.into_response();
}
};
tracing::warn!("Middleware extracted cache_token: {:?}", &auth_model);
let uid = auth_model.uid;
let tid = auth_model.tid;
let ouid = auth_model.ouid;
request.extensions_mut().insert(auth_model);
} else {
tracing::warn!(
"Middleware Missing Authorization BEARER header and accessToken query param"
);
return AppError::Unauthorized.into_response();
}
let body_bytes = match read_and_print_body(&mut request).await {
Ok(b) => b,
Err(e) => return e.into_response(),
};
let modified_bytes = match modify_body(body_bytes, &mut request).await {
Ok(b) => b,
Err(e) => return e.into_response(),
};
*request.body_mut() = Body::from(modified_bytes);
let response = next.run(request).await;
response
}
/// Moves the request body out of `request` and collects it into memory.
///
/// After the call, `request` holds an empty body. Despite the function's name, nothing is
/// printed or logged.
///
/// The collection limit is `usize::MAX`, i.e. effectively unbounded.
/// NOTE(review): this ignores any per-extractor body limit. Confirm an outer layer bounds
/// request size.
///
/// # Errors
/// Returns `AppError::ClientError` if the body stream errors while being collected.
async fn read_and_print_body(request: &mut Request) -> AppResult<Bytes> {
    let taken = std::mem::replace(request.body_mut(), Body::empty());
    match axum::body::to_bytes(taken, usize::MAX).await {
        Ok(collected) => Ok(collected),
        Err(_) => Err(AppError::ClientError(
            "Middleware Invalid request body".into(),
        )),
    }
}
async fn modify_body(bytes: Bytes, request: &mut Request) -> AppResult<Bytes> {
if bytes.is_empty() {
return Ok(bytes);
}
if let Ok(mut json) = serde_json::from_slice::<Value>(&bytes) {
match &mut json {
Value::Object(obj) => {
insert_auth_fields(obj, request);
}
Value::Array(arr) => {
for item in arr.iter_mut() {
if let Value::Object(obj) = item {
insert_auth_fields(obj, request);
}
}
}
_ => {
tracing::warn!("Middleware Interceptor json is not object or array");
}
}
let modified_bytes = serde_json::to_vec(&json)
.map_err(|_| AppError::Internal("Middleware Interceptor JSON encode error".into()))?;
return Ok(Bytes::from(modified_bytes));
} else {
tracing::warn!("Middleware Interceptor json parse failed");
}
Ok(bytes)
}
/// Writes audit fields into `obj` based on the request method.
///
/// - `POST`: sets `creator`, `creator_by`, `updater`, `updater_by`.
/// - `PUT`: sets `updater`, `updater_by`.
/// - Any other method: leaves `obj` untouched.
///
/// The id fields get the `uid` and the `*_by` fields get the `nickname` of the `AuthModel`
/// found in the request extensions. Without one, they get `0` and `"anonymous"` respectively.
/// Existing keys with the same names are overwritten.
fn insert_auth_fields(obj: &mut serde_json::Map<String, Value>, request: &mut Request) {
    let stamps_creator = match request.method().as_str() {
        "POST" => true,
        "PUT" => false,
        _ => return,
    };
    let (actor_id, actor_name) = match request.extensions().get::<AuthModel>() {
        Some(model) => (json!(model.uid), json!(model.nickname)),
        None => (json!(0), json!("anonymous")),
    };
    if stamps_creator {
        obj.insert("creator".to_string(), actor_id.clone());
        obj.insert("creator_by".to_string(), actor_name.clone());
    }
    obj.insert("updater".to_string(), actor_id);
    obj.insert("updater_by".to_string(), actor_name);
}