use crate::server::OrdinaryAppServerState;
use axum::Router;
use axum::extract::{Request, State};
use axum::http::StatusCode;
use axum::middleware::Next;
use axum::response::{IntoResponse, Response};
use axum::routing::MethodRouter;
use ordinary_config::{
MiddlewareConfig, MiddlewareMechanism, MiddlewareOperation, MiddlewareValidationComponent,
MiddlewareValidationRule, OrdinaryConfig,
};
use ordinary_utils::headers::{get_request_headers_for_forward, log_response};
use ordinary_utils::response::get_response_for_forwarded;
use ordinary_utils::{HeadersDebug, WrappedRedactedHashingAlg};
use std::sync::Arc;
use std::time::Instant;
use tower::ServiceBuilder;
use tracing::Span;
use url::Url;
#[allow(clippy::too_many_arguments)]
pub(crate) fn apply_custom_to_router<T>(
router: Router<T>,
config: &Arc<OrdinaryConfig>,
state: &Arc<OrdinaryAppServerState>,
names: &Vec<String>,
forwarding_domain: String,
forwarded_by: String,
forwarded_proto: String,
api_domain: Option<String>,
) -> Router<T>
where
T: Clone + Send + Sync + 'static,
{
let middleware = if let Some(middlewares) = config.get_middlewares(names) {
let middlewares = Arc::new(middlewares);
Some(axum::middleware::from_fn_with_state(
state.clone(),
move |State(state): State<Arc<OrdinaryAppServerState>>, req: Request, next: Next| {
run_custom(
state.clone(),
middlewares.clone(),
req,
next,
forwarded_by.clone(),
forwarded_proto.clone(),
state.log_headers,
state.redacted_hash.clone(),
forwarding_domain.clone(),
api_domain.clone(),
)
},
))
} else {
None
};
router.layer(ServiceBuilder::new().option_layer(middleware))
}
#[allow(clippy::too_many_arguments)]
pub(crate) fn apply_custom_to_method_router<T>(
router: MethodRouter<T>,
config: &Arc<OrdinaryConfig>,
state: &Arc<OrdinaryAppServerState>,
names: &Vec<String>,
forwarding_domain: String,
forwarded_by: String,
forwarded_proto: String,
api_domain: Option<String>,
) -> MethodRouter<T>
where
T: Clone + Send + Sync + 'static,
{
let middleware = if let Some(middlewares) = config.get_middlewares(names) {
let middlewares = Arc::new(middlewares);
Some(axum::middleware::from_fn_with_state(
state.clone(),
move |State(state): State<Arc<OrdinaryAppServerState>>, req: Request, next: Next| {
run_custom(
state.clone(),
middlewares.clone(),
req,
next,
forwarded_by.clone(),
forwarded_proto.clone(),
state.log_headers,
state.redacted_hash.clone(),
forwarding_domain.clone(),
api_domain.clone(),
)
},
))
} else {
None
};
router.layer(ServiceBuilder::new().option_layer(middleware))
}
#[allow(clippy::too_many_arguments, clippy::too_many_lines)]
async fn run_custom(
state: Arc<OrdinaryAppServerState>,
middlewares: Arc<Vec<MiddlewareConfig>>,
req: Request,
next: Next,
forwarded_by: String,
forwarded_proto: String,
log_headers: bool,
redacted_hash: Arc<Option<WrappedRedactedHashingAlg>>,
forwarding_domain: String,
api_domain: Option<String>,
) -> Response {
let via_domain = api_domain.unwrap_or(forwarding_domain);
for middleware in middlewares.iter() {
match &middleware.mechanism {
MiddlewareMechanism::Request { endpoint } => match endpoint.parse() {
Ok(uri) => {
let start = Instant::now();
let mut out_req = reqwest::Request::new(req.method().clone(), uri);
match &middleware.operation {
MiddlewareOperation::Validate { rule, components } => {
for component in components {
match component {
MiddlewareValidationComponent::Headers => {
let headers = get_request_headers_for_forward(
&req,
forwarded_by.as_str(),
forwarded_proto.as_str(),
via_domain.as_str(),
);
*out_req.headers_mut() = headers;
}
MiddlewareValidationComponent::Path => {
if let Ok(url) = Url::parse(&req.uri().to_string())
&& let Some(segments) = url.path_segments()
{
for segment in segments {
if let Ok(mut segments) =
out_req.url_mut().path_segments_mut()
{
segments.push(segment);
}
}
}
}
MiddlewareValidationComponent::Query => {
if let Ok(url) = Url::parse(&req.uri().to_string()) {
for (name, value) in url.query_pairs() {
out_req
.url_mut()
.query_pairs_mut()
.append_pair(&name, &value);
}
}
}
}
}
let host = out_req.url().host().map(tracing::field::display);
let port = out_req.url().port().map(tracing::field::display);
let query = out_req.url().query().map(tracing::field::display);
let span = tracing::info_span!("mdw", nm = %middleware.name, host, port, path = %out_req.url().path(), query);
span.in_scope(|| {
let hd = log_headers.then_some(HeadersDebug(
out_req.headers(),
redacted_hash.clone(),
));
#[cfg(tracing_unstable)]
let headers = log_headers.then_some(tracing::field::valuable(&hd));
#[cfg(not(tracing_unstable))]
let headers = log_headers.then_some(tracing::field::debug(&hd));
tracing::info!(
version = ?out_req.version(),
method = %out_req.method(),
headers,
"req"
);
});
match state.reqwest_client.execute(out_req).await {
Ok(res) => {
let status = res.status().as_u16();
span.in_scope(|| {
log_response(
status,
log_headers,
&redacted_hash,
start,
res.headers(),
);
});
if !check_valid(middleware, &span, rule, status, false) {
return get_response_for_forwarded(
via_domain.as_str(),
res,
)
.into_response();
}
}
Err(err) => {
tracing::error!(%err);
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
}
}
}
}
}
Err(err) => {
tracing::error!(%err);
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
}
},
}
}
next.run(req).await
}
/// Evaluates `rule` against the response `status` and reports whether it holds.
///
/// `is_not` requests the *negated* outcome of `rule`; top-level callers pass `false`.
/// Negation is carried down to the `StatusCode` leaves so that the log lines reflect
/// the comparison that was actually required:
///
/// * `Not(r)` flips `is_not` for `r`, so a double negation cancels out;
/// * a negated `Any` holds only if every child holds negated, and a negated `All`
///   holds if at least one child holds negated (De Morgan).
///
/// `Any` is valid only when at least one child is valid, so an empty `Any` is
/// invalid; an empty `All` is valid. "At least one" evaluations stop at the first
/// valid child, while "every" evaluations visit all children so that each failing
/// leaf is logged.
///
/// Every evaluated `StatusCode` leaf logs `"valid"` (info) or `"invalid"` (warn)
/// inside `span`, tagged with the middleware name, the received status and the
/// status named by the rule.
fn check_valid(
    middleware: &MiddlewareConfig,
    span: &Span,
    rule: &MiddlewareValidationRule,
    status: u16,
    is_not: bool,
) -> bool {
    match rule {
        MiddlewareValidationRule::Any(rules) => {
            if is_not {
                // !(a || b) == !a && !b: every child must hold negated.
                let mut is_valid = true;
                for rule in rules {
                    if !check_valid(middleware, span, rule, status, true) {
                        is_valid = false;
                    }
                }
                is_valid
            } else {
                for rule in rules {
                    if check_valid(middleware, span, rule, status, false) {
                        return true;
                    }
                }
                // No child matched. Falling through to `true` here used to make
                // `Any` impossible to fail.
                false
            }
        }
        MiddlewareValidationRule::All(rules) => {
            if is_not {
                // !(a && b) == !a || !b: one negated child is enough.
                for rule in rules {
                    if check_valid(middleware, span, rule, status, true) {
                        return true;
                    }
                }
                false
            } else {
                let mut is_valid = true;
                for rule in rules {
                    if !check_valid(middleware, span, rule, status, false) {
                        is_valid = false;
                    }
                }
                is_valid
            }
        }
        // Flip rather than force `true`, so nested `Not`s cancel each other.
        MiddlewareValidationRule::Not(rule) => {
            check_valid(middleware, span, rule, status, !is_not)
        }
        MiddlewareValidationRule::StatusCode(code) => {
            let is_valid = (status == *code) != is_not;

            if is_valid {
                span.in_scope(|| {
                    tracing::info!(
                        mdw.name = %middleware.name,
                        res.status = status,
                        rule.status = code,
                        "valid"
                    );
                });
            } else {
                span.in_scope(|| {
                    tracing::warn!(
                        mdw.name = %middleware.name,
                        res.status = status,
                        rule.status = code,
                        "invalid"
                    );
                });
            }

            is_valid
        }
    }
}