use crate::error::{AccessDenied, AccessDeniedHandler, DefaultDeniedHandler};
use crate::extractor::{
AuthExtractor, AuthResult, HeaderIdExtractor, HeaderRoleExtractor, IdExtractor, RoleExtractor,
};
use crate::rule::{AclAction, BitmaskAuth, RequestMeta};
use crate::table::AclTable;
use axum::extract::ConnectInfo;
use axum::response::Response;
use futures_util::future::BoxFuture;
use http::{Request, StatusCode};
use http_body::Body;
use std::collections::HashMap;
use std::net::{IpAddr, SocketAddr};
use std::sync::Arc;
use std::task::{Context, Poll};
use tower::{Layer, Service};
/// Shared configuration carried by [`AclLayer`] and every [`AclMiddleware`] it creates.
///
/// The table, extractors and denied handler are held behind `Arc`s; the manual
/// `Clone` impl for this type clones those handles rather than the values.
pub struct AclConfig<E, I> {
/// Rule table evaluated for each request.
pub table: Arc<AclTable>,
/// Extractor asked for the caller's roles on each request.
pub role_extractor: Arc<E>,
/// Extractor asked for the caller's id on each request.
pub id_extractor: Arc<I>,
/// Handler invoked to build the response when the table yields `AclAction::Deny`.
pub denied_handler: Arc<dyn AccessDeniedHandler>,
/// Fallback role value handed to `roles_or` on the role extraction result.
pub anonymous_roles: u32,
/// Optional name of a request header to read the client IP from before
/// falling back to the connection's socket address.
pub forwarded_ip_header: Option<String>,
/// Fallback id handed to `id_or` on the id extraction result.
pub default_id: String,
}
impl<E, I> Clone for AclConfig<E, I> {
fn clone(&self) -> Self {
Self {
table: self.table.clone(),
role_extractor: self.role_extractor.clone(),
id_extractor: self.id_extractor.clone(),
denied_handler: self.denied_handler.clone(),
anonymous_roles: self.anonymous_roles,
forwarded_ip_header: self.forwarded_ip_header.clone(),
default_id: self.default_id.clone(),
}
}
}
/// Tower layer that wraps services in [`AclMiddleware`], sharing one
/// [`AclConfig`] between all of them.
pub struct AclLayer<E, I> {
    config: AclConfig<E, I>,
}

// Written by hand rather than derived: `#[derive(Clone)]` would require
// `E: Clone` and `I: Clone`, even though cloning only copies the `Arc`
// handles inside `AclConfig`, whose own `Clone` impl has no such bounds.
impl<E, I> Clone for AclLayer<E, I> {
    fn clone(&self) -> Self {
        Self {
            config: self.config.clone(),
        }
    }
}
impl AclLayer<HeaderRoleExtractor, HeaderIdExtractor> {
pub fn new(table: AclTable) -> Self {
Self {
config: AclConfig {
table: Arc::new(table),
role_extractor: Arc::new(HeaderRoleExtractor::new("X-Roles")),
id_extractor: Arc::new(HeaderIdExtractor::new("X-User-Id")),
denied_handler: Arc::new(DefaultDeniedHandler),
anonymous_roles: 0,
forwarded_ip_header: None,
default_id: "*".to_string(),
},
}
}
}
impl<E, I> AclLayer<E, I> {
    /// Replaces the role extractor, changing the layer's `E` type parameter.
    /// All other settings are carried over unchanged.
    pub fn with_role_extractor<E2>(self, extractor: E2) -> AclLayer<E2, I> {
        let AclConfig {
            table,
            id_extractor,
            denied_handler,
            anonymous_roles,
            forwarded_ip_header,
            default_id,
            ..
        } = self.config;

        AclLayer {
            config: AclConfig {
                table,
                role_extractor: Arc::new(extractor),
                id_extractor,
                denied_handler,
                anonymous_roles,
                forwarded_ip_header,
                default_id,
            },
        }
    }

    /// Replaces the id extractor, changing the layer's `I` type parameter.
    /// All other settings are carried over unchanged.
    pub fn with_id_extractor<I2>(self, extractor: I2) -> AclLayer<E, I2> {
        let AclConfig {
            table,
            role_extractor,
            denied_handler,
            anonymous_roles,
            forwarded_ip_header,
            default_id,
            ..
        } = self.config;

        AclLayer {
            config: AclConfig {
                table,
                role_extractor,
                id_extractor: Arc::new(extractor),
                denied_handler,
                anonymous_roles,
                forwarded_ip_header,
                default_id,
            },
        }
    }

    /// Deprecated alias that forwards to [`AclLayer::with_role_extractor`].
    #[deprecated(since = "0.2.0", note = "Use with_role_extractor instead")]
    pub fn with_extractor<E2>(self, extractor: E2) -> AclLayer<E2, I> {
        self.with_role_extractor(extractor)
    }

    /// Sets the handler used to build the response for denied requests.
    pub fn with_denied_handler(self, handler: impl AccessDeniedHandler + 'static) -> Self {
        Self {
            config: AclConfig {
                denied_handler: Arc::new(handler),
                ..self.config
            },
        }
    }

    /// Sets the fallback role value supplied to the role extraction result.
    pub fn with_anonymous_roles(self, roles: u32) -> Self {
        Self {
            config: AclConfig {
                anonymous_roles: roles,
                ..self.config
            },
        }
    }

    /// Sets the request header consulted for the client IP before the
    /// connection's socket address is used.
    pub fn with_forwarded_ip_header(self, header: impl Into<String>) -> Self {
        Self {
            config: AclConfig {
                forwarded_ip_header: Some(header.into()),
                ..self.config
            },
        }
    }

    /// Sets the fallback id supplied to the id extraction result.
    pub fn with_default_id(self, id: impl Into<String>) -> Self {
        Self {
            config: AclConfig {
                default_id: id.into(),
                ..self.config
            },
        }
    }

    /// Borrows the rule table this layer was configured with.
    pub fn table(&self) -> &AclTable {
        &*self.config.table
    }
}
/// Wraps `inner` in an [`AclMiddleware`] that shares this layer's configuration.
///
/// No `Clone` bounds are placed on `E` or `I`: the only thing cloned here is
/// the `AclConfig`, whose `Clone` impl duplicates `Arc` handles and is
/// available for every `E` and `I`.
impl<S, E, I> Layer<S> for AclLayer<E, I> {
    type Service = AclMiddleware<S, E, I>;

    fn layer(&self, inner: S) -> Self::Service {
        AclMiddleware {
            inner,
            config: self.config.clone(),
        }
    }
}
/// Service created by [`AclLayer`]: evaluates the ACL table for each request
/// before deciding whether to call the wrapped `inner` service.
pub struct AclMiddleware<S, E, I> {
    inner: S,
    config: AclConfig<E, I>,
}

// Written by hand rather than derived: `#[derive(Clone)]` would also demand
// `E: Clone` and `I: Clone`, but only the inner service needs a real clone;
// the configuration is cloned through its bound-free `Clone` impl.
impl<S: Clone, E, I> Clone for AclMiddleware<S, E, I> {
    fn clone(&self) -> Self {
        Self {
            inner: self.inner.clone(),
            config: self.config.clone(),
        }
    }
}
/// ACL enforcement for role/id based (`BitmaskAuth`) tables.
///
/// For each request the middleware extracts roles, id and client IP,
/// evaluates the table and hands the resulting action to `handle_action`.
/// If no client IP can be determined, it answers with
/// `500 Internal Server Error` and does not call the inner service.
impl<S, E, I, ReqBody, ResBody> Service<Request<ReqBody>> for AclMiddleware<S, E, I>
where
    S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + 'static,
    S::Future: Send,
    E: RoleExtractor<ReqBody> + 'static,
    I: IdExtractor<ReqBody> + 'static,
    ReqBody: Body + Send + 'static,
    ResBody: Body + Default + Send + 'static,
{
    type Response = Response<ResBody>;
    type Error = S::Error;
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

    /// Readiness is delegated to the wrapped service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, request: Request<ReqBody>) -> Self::Future {
        let config = self.config.clone();

        // `poll_ready` made `self.inner` ready, not a fresh clone of it.
        // Move the ready instance into the future and leave the clone behind
        // so the next `poll_ready` drives that one instead.
        let fresh = self.inner.clone();
        let mut inner = std::mem::replace(&mut self.inner, fresh);

        // Everything that borrows the request is extracted up front so the
        // request itself can be moved into the future untouched.
        let role_result = config.role_extractor.extract_roles(&request);
        let roles = role_result.roles_or(config.anonymous_roles);
        let client_ip = extract_client_ip(&request, config.forwarded_ip_header.as_deref());
        let id_result = config.id_extractor.extract_id(&request);
        let id = id_result.id_or(&config.default_id);
        let method = request.method().clone();
        let path = request.uri().path().to_string();

        Box::pin(async move {
            let Some(client_ip) = client_ip else {
                tracing::warn!("Failed to extract client IP address");
                let response = Response::builder()
                    .status(StatusCode::INTERNAL_SERVER_ERROR)
                    .body(ResBody::default())
                    .unwrap();
                return Ok(response);
            };

            let auth = BitmaskAuth {
                roles,
                id: id.clone(),
            };
            let meta = RequestMeta {
                method,
                path: path.clone(),
                path_params: HashMap::new(),
                ip: client_ip,
            };
            let action = config.table.evaluate_request(&auth, &meta);

            handle_action(
                action,
                &path,
                &id,
                roles,
                client_ip,
                &config.denied_handler,
                request,
                &mut inner,
            )
            .await
        })
    }
}
/// Shared configuration carried by [`GenericAclLayer`] and every
/// [`GenericAclMiddleware`] it creates, for tables keyed on an auth type `A`.
pub struct GenericAclConfig<A, X> {
/// Rule table evaluated for each request that yields an auth value.
pub table: Arc<AclTable<A>>,
/// Extractor asked for the caller's auth value on each request.
pub auth_extractor: Arc<X>,
/// Handler invoked to build the response when the action is `AclAction::Deny`.
pub denied_handler: Arc<dyn AccessDeniedHandler>,
/// Optional name of a request header to read the client IP from before
/// falling back to the connection's socket address.
pub forwarded_ip_header: Option<String>,
}
impl<A, X> Clone for GenericAclConfig<A, X> {
fn clone(&self) -> Self {
Self {
table: self.table.clone(),
auth_extractor: self.auth_extractor.clone(),
denied_handler: self.denied_handler.clone(),
forwarded_ip_header: self.forwarded_ip_header.clone(),
}
}
}
/// Tower layer that wraps services in [`GenericAclMiddleware`], sharing one
/// [`GenericAclConfig`] between all of them.
pub struct GenericAclLayer<A, X> {
    config: GenericAclConfig<A, X>,
}

// Written by hand rather than derived: `#[derive(Clone)]` would require
// `A: Clone` and `X: Clone`, even though cloning only copies the `Arc`
// handles inside `GenericAclConfig`, whose `Clone` impl has no such bounds.
impl<A, X> Clone for GenericAclLayer<A, X> {
    fn clone(&self) -> Self {
        Self {
            config: self.config.clone(),
        }
    }
}
impl<A, X> GenericAclLayer<A, X> {
    /// Builds a layer from a table and an auth extractor, using
    /// [`DefaultDeniedHandler`] and no forwarded-IP header.
    pub fn with_auth(table: AclTable<A>, extractor: X) -> Self {
        let config = GenericAclConfig {
            table: Arc::new(table),
            auth_extractor: Arc::new(extractor),
            denied_handler: Arc::new(DefaultDeniedHandler),
            forwarded_ip_header: None,
        };
        Self { config }
    }

    /// Sets the handler used to build the response for denied requests.
    pub fn with_denied_handler(self, handler: impl AccessDeniedHandler + 'static) -> Self {
        Self {
            config: GenericAclConfig {
                denied_handler: Arc::new(handler),
                ..self.config
            },
        }
    }

    /// Sets the request header consulted for the client IP before the
    /// connection's socket address is used.
    pub fn with_forwarded_ip_header(self, header: impl Into<String>) -> Self {
        Self {
            config: GenericAclConfig {
                forwarded_ip_header: Some(header.into()),
                ..self.config
            },
        }
    }
}
/// Wraps `inner` in a [`GenericAclMiddleware`] that shares this layer's
/// configuration.
///
/// No `Clone` bounds are placed on `A` or `X`: the only thing cloned here is
/// the `GenericAclConfig`, whose `Clone` impl duplicates `Arc` handles and is
/// available for every `A` and `X`.
impl<S, A, X> Layer<S> for GenericAclLayer<A, X> {
    type Service = GenericAclMiddleware<S, A, X>;

    fn layer(&self, inner: S) -> Self::Service {
        GenericAclMiddleware {
            inner,
            config: self.config.clone(),
        }
    }
}
/// Service created by [`GenericAclLayer`]: evaluates the ACL table for each
/// request before deciding whether to call the wrapped `inner` service.
pub struct GenericAclMiddleware<S, A, X> {
    inner: S,
    config: GenericAclConfig<A, X>,
}

// Written by hand rather than derived: `#[derive(Clone)]` would also demand
// `A: Clone` and `X: Clone`, but only the inner service needs a real clone;
// the configuration is cloned through its bound-free `Clone` impl.
impl<S: Clone, A, X> Clone for GenericAclMiddleware<S, A, X> {
    fn clone(&self) -> Self {
        Self {
            inner: self.inner.clone(),
            config: self.config.clone(),
        }
    }
}
/// ACL enforcement for tables keyed on a caller-defined auth type `A`.
///
/// The extractor result selects the action: an extracted auth value is
/// evaluated against the table, an anonymous caller gets the table's default
/// action, and an extraction error is logged and treated as `Deny`.
/// If no client IP can be determined, the middleware answers with
/// `500 Internal Server Error` and does not call the inner service.
impl<S, A, X, ReqBody, ResBody> Service<Request<ReqBody>> for GenericAclMiddleware<S, A, X>
where
    S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + 'static,
    S::Future: Send,
    A: Send + Sync + 'static,
    X: AuthExtractor<A, ReqBody> + 'static,
    ReqBody: Body + Send + 'static,
    ResBody: Body + Default + Send + 'static,
{
    type Response = Response<ResBody>;
    type Error = S::Error;
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

    /// Readiness is delegated to the wrapped service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, request: Request<ReqBody>) -> Self::Future {
        let config = self.config.clone();

        // `poll_ready` made `self.inner` ready, not a fresh clone of it.
        // Move the ready instance into the future and leave the clone behind
        // so the next `poll_ready` drives that one instead.
        let fresh = self.inner.clone();
        let mut inner = std::mem::replace(&mut self.inner, fresh);

        // Everything that borrows the request is extracted up front so the
        // request itself can be moved into the future untouched.
        let auth_result = config.auth_extractor.extract_auth(&request);
        let client_ip = extract_client_ip(&request, config.forwarded_ip_header.as_deref());
        let method = request.method().clone();
        let path = request.uri().path().to_string();

        Box::pin(async move {
            let Some(client_ip) = client_ip else {
                tracing::warn!("Failed to extract client IP address");
                let response = Response::builder()
                    .status(StatusCode::INTERNAL_SERVER_ERROR)
                    .body(ResBody::default())
                    .unwrap();
                return Ok(response);
            };

            let meta = RequestMeta {
                method,
                path: path.clone(),
                path_params: HashMap::new(),
                ip: client_ip,
            };
            let action = match auth_result {
                AuthResult::Auth(auth) => config.table.evaluate_request(&auth, &meta),
                AuthResult::Anonymous => config.table.default_action(),
                AuthResult::Error(e) => {
                    tracing::warn!(error = %e, "Auth extraction failed");
                    AclAction::Deny
                }
            };

            // The denied handler is given the placeholder id "*" and roles 0
            // on this path; no id or role value is derived from `A` here.
            handle_action(
                action,
                &path,
                "*",
                0,
                client_ip,
                &config.denied_handler,
                request,
                &mut inner,
            )
            .await
        })
    }
}
async fn handle_action<S, ReqBody, ResBody>(
action: AclAction,
path: &str,
id: &str,
roles: u32,
client_ip: IpAddr,
denied_handler: &Arc<dyn AccessDeniedHandler>,
request: Request<ReqBody>,
inner: &mut S,
) -> Result<Response<ResBody>, S::Error>
where
S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + 'static,
S::Future: Send,
ResBody: Body + Default + Send + 'static,
{
match action {
AclAction::Allow => {
tracing::trace!(
path = %path,
ip = %client_ip,
"ACL allowed request"
);
inner.call(request).await
}
AclAction::Deny => {
tracing::info!(
path = %path,
ip = %client_ip,
"ACL denied request"
);
let denied = AccessDenied::new_with_roles(roles, path, id);
let response = denied_handler.handle(&denied);
let (parts, _body) = response.into_parts();
let response = Response::from_parts(parts, ResBody::default());
Ok(response)
}
AclAction::Error { code, ref message } => {
tracing::info!(
path = %path,
ip = %client_ip,
code = code,
message = ?message,
"ACL returned error"
);
let status = StatusCode::from_u16(code).unwrap_or(StatusCode::FORBIDDEN);
let response = Response::builder()
.status(status)
.header("content-type", "text/plain")
.body(ResBody::default())
.unwrap();
Ok(response)
}
AclAction::Reroute {
ref target,
preserve_path,
} => {
tracing::info!(
path = %path,
ip = %client_ip,
target = %target,
"ACL rerouting request"
);
let mut response = Response::builder()
.status(StatusCode::TEMPORARY_REDIRECT)
.header("location", target.as_str())
.body(ResBody::default())
.unwrap();
if preserve_path {
response.headers_mut().insert(
"x-original-path",
path.parse().unwrap_or_else(|_| "/".parse().unwrap()),
);
}
Ok(response)
}
AclAction::RateLimit {
max_requests,
window_secs,
} => {
tracing::warn!(
path = %path,
ip = %client_ip,
max_requests = max_requests,
window_secs = window_secs,
"ACL rate limit action - not implemented, allowing request"
);
inner.call(request).await
}
AclAction::Log {
ref level,
ref message,
} => {
let msg = message.clone().unwrap_or_else(|| {
format!("ACL log: path={}, ip={}", path, client_ip)
});
match level.as_str() {
"trace" => tracing::trace!("{}", msg),
"debug" => tracing::debug!("{}", msg),
"warn" => tracing::warn!("{}", msg),
"error" => tracing::error!("{}", msg),
_ => tracing::info!("{}", msg),
}
inner.call(request).await
}
}
}
/// Determines the client IP for `request`.
///
/// When `forwarded_header` names a header that is present, readable as a
/// string, and whose first comma-separated entry parses as an IP address,
/// that address is returned. In every other case the IP of the
/// `ConnectInfo<SocketAddr>` request extension is used; `None` is returned if
/// that extension is absent as well.
///
/// NOTE(review): the first entry of a forwarded list is taken as-is, so the
/// header should only be configured when it is set by a trusted proxy —
/// confirm against the deployment.
fn extract_client_ip<B>(request: &Request<B>, forwarded_header: Option<&str>) -> Option<IpAddr> {
    let from_header = forwarded_header
        .and_then(|name| request.headers().get(name))
        .and_then(|raw| raw.to_str().ok())
        .and_then(|list| list.split(',').next())
        .and_then(|candidate| candidate.trim().parse::<IpAddr>().ok());

    from_header.or_else(|| {
        request
            .extensions()
            .get::<ConnectInfo<SocketAddr>>()
            .map(|peer| peer.0.ip())
    })
}
// Test-only module; it currently contains no tests.
#[cfg(test)]
mod tests {
}