#![cfg(feature = "axum")]
use axum::{body::Body, extract::Request, http::StatusCode, response::Response};
use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
};
use tower::{Layer, Service};
use tracing::{debug, error, warn};
use super::auth::{AuthContext, PermissionScope};
fn build_error_response(status: StatusCode, message: &str) -> Response {
let sanitized_message = message
.chars()
.filter(|c| c.is_ascii() && !c.is_control())
.take(1000)
.collect::<String>();
let mut builder = Response::builder().status(status);
if status == StatusCode::UNAUTHORIZED {
builder = builder.header("WWW-Authenticate", "Bearer");
}
builder = builder.header("X-Auth-Error", sanitized_message);
builder
.body(Body::from(message.to_string()))
.unwrap_or_else(|_| {
Response::builder()
.status(status)
.body(Body::from(format!("Error: {status}")))
.unwrap_or_else(|_| Response::new(Body::from("Error")))
})
}
/// Layer that requires a single permission within a given [`PermissionScope`].
///
/// Wrapping a service yields a [`PermissionService`] that performs the check on every
/// request before delegating to the wrapped service.
#[derive(Clone)]
pub struct PermissionLayer {
    /// Permission name that must be present in the caller's granted permissions.
    permission: String,
    /// Which granted-permission set (organization or workspace) is consulted.
    scope: PermissionScope,
}

impl PermissionLayer {
    /// Requires `permission` in the organization scope.
    pub fn organization(permission: impl Into<String>) -> Self {
        Self::new(permission, PermissionScope::Organization)
    }

    /// Requires `permission` in the workspace scope.
    pub fn workspace(permission: impl Into<String>) -> Self {
        Self::new(permission, PermissionScope::Workspace)
    }

    /// Requires `permission` in an explicitly chosen `scope`.
    pub fn new(permission: impl Into<String>, scope: PermissionScope) -> Self {
        let permission = permission.into();
        Self { permission, scope }
    }
}

impl<S> Layer<S> for PermissionLayer {
    type Service = PermissionService<S>;

    /// Wraps `inner`, giving the new service its own copy of the requirement.
    fn layer(&self, inner: S) -> Self::Service {
        let Self { permission, scope } = self;
        PermissionService {
            inner,
            permission: permission.clone(),
            scope: scope.clone(),
        }
    }
}
#[derive(Clone)]
pub struct PermissionService<S> {
inner: S,
permission: String,
scope: PermissionScope,
}
impl<S> Service<Request<Body>> for PermissionService<S>
where
S: Service<Request<Body>, Response = Response> + Send + 'static + Clone,
S::Future: Send + 'static,
{
type Response = Response;
type Error = std::convert::Infallible;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
match self.inner.poll_ready(cx) {
Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
Poll::Ready(Err(_)) => {
error!("PermissionService: Inner service poll_ready returned error");
Poll::Ready(Ok(()))
}
Poll::Pending => Poll::Pending,
}
}
fn call(&mut self, req: Request<Body>) -> Self::Future {
let permission = self.permission.clone();
let scope = self.scope.clone();
let mut inner = self.inner.clone();
Box::pin(async move {
let auth_context = match req.extensions().get::<AuthContext>() {
Some(ctx) => ctx,
None => {
return Ok(build_error_response(
StatusCode::UNAUTHORIZED,
"No authentication context found",
));
}
};
let has_permission = match scope {
PermissionScope::Organization => auth_context
.permissions
.as_ref()
.map(|perms| {
perms
.organization
.as_ref()
.map(|perms| perms.contains(&permission))
.unwrap_or(false)
})
.unwrap_or(false),
PermissionScope::Workspace => auth_context
.permissions
.as_ref()
.map(|perms| {
perms
.workspace
.as_ref()
.map(|perms| perms.contains(&permission))
.unwrap_or(false)
})
.unwrap_or(false),
};
if !has_permission {
let error_msg = format!("Missing required permission: {permission}");
return Ok(build_error_response(StatusCode::FORBIDDEN, &error_msg));
}
match inner.call(req).await {
Ok(response) => Ok(response),
Err(_) => {
error!("PermissionService: Inner service call failed");
Ok(build_error_response(
StatusCode::INTERNAL_SERVER_ERROR,
"Internal server error",
))
}
}
})
}
}
/// Layer that checks several `(permission, scope)` pairs in one pass.
///
/// Build it with [`MultiplePermissionLayers::all`] (every pair must be granted) or
/// [`MultiplePermissionLayers::any`] (at least one pair must be granted). Wrapping a
/// service yields a [`MultiplePermissionService`].
#[derive(Clone)]
pub struct MultiplePermissionLayers {
    /// Permission names, each paired with the scope it is looked up in.
    permissions: Vec<(String, PermissionScope)>,
    /// `true`: every pair is required. `false`: any single pair suffices.
    require_all: bool,
}

impl MultiplePermissionLayers {
    /// Requires every listed permission.
    pub fn all(permissions: Vec<(&str, PermissionScope)>) -> Self {
        Self::from_pairs(permissions, true)
    }

    /// Requires at least one of the listed permissions.
    pub fn any(permissions: Vec<(&str, PermissionScope)>) -> Self {
        Self::from_pairs(permissions, false)
    }

    /// Converts borrowed permission names into owned ones and records the mode.
    fn from_pairs(permissions: Vec<(&str, PermissionScope)>, require_all: bool) -> Self {
        let permissions = permissions
            .into_iter()
            .map(|(name, scope)| (name.to_owned(), scope))
            .collect();
        Self {
            permissions,
            require_all,
        }
    }
}

impl<S> Layer<S> for MultiplePermissionLayers {
    type Service = MultiplePermissionService<S>;

    /// Wraps `inner`, giving the new service its own copy of the requirement list.
    fn layer(&self, inner: S) -> Self::Service {
        let Self {
            permissions,
            require_all,
        } = self;
        MultiplePermissionService {
            inner,
            permissions: permissions.clone(),
            require_all: *require_all,
        }
    }
}
#[derive(Clone)]
pub struct MultiplePermissionService<S> {
inner: S,
permissions: Vec<(String, PermissionScope)>,
require_all: bool,
}
impl<S> Service<Request<Body>> for MultiplePermissionService<S>
where
S: Service<Request<Body>, Response = Response> + Send + 'static + Clone,
S::Future: Send + 'static,
{
type Response = Response;
type Error = std::convert::Infallible;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
match self.inner.poll_ready(cx) {
Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
Poll::Ready(Err(_)) => {
error!("PermissionService: Inner service poll_ready returned error");
Poll::Ready(Ok(()))
}
Poll::Pending => Poll::Pending,
}
}
fn call(&mut self, req: Request<Body>) -> Self::Future {
let permissions = self.permissions.clone();
let require_all = self.require_all;
let mut inner = self.inner.clone();
Box::pin(async move {
let auth_context = match req.extensions().get::<AuthContext>() {
Some(ctx) => ctx,
None => {
return Ok(build_error_response(
StatusCode::UNAUTHORIZED,
"No authentication context found",
));
}
};
let check_permission = |permission: &str, scope: &PermissionScope| -> bool {
match scope {
PermissionScope::Organization => auth_context
.permissions
.as_ref()
.map(|perms| {
perms
.organization
.as_ref()
.map(|perms| perms.contains(&permission.to_string()))
.unwrap_or(false)
})
.unwrap_or(false),
PermissionScope::Workspace => auth_context
.permissions
.as_ref()
.map(|perms| {
perms
.workspace
.as_ref()
.map(|perms| perms.contains(&permission.to_string()))
.unwrap_or(false)
})
.unwrap_or(false),
}
};
let has_permission = if require_all {
permissions.iter().all(|(p, s)| check_permission(p, s))
} else {
permissions.iter().any(|(p, s)| check_permission(p, s))
};
if !has_permission {
let message = if require_all {
format!(
"Missing required permissions: {}",
permissions
.iter()
.map(|(p, _)| p.as_str())
.collect::<Vec<_>>()
.join(" AND ")
)
} else {
format!(
"Missing required permission: {}",
permissions
.iter()
.map(|(p, _)| p.as_str())
.collect::<Vec<_>>()
.join(" OR ")
)
};
return Ok(build_error_response(StatusCode::FORBIDDEN, &message));
}
match inner.call(req).await {
Ok(response) => Ok(response),
Err(_) => {
error!("PermissionService: Inner service call failed");
Ok(build_error_response(
StatusCode::INTERNAL_SERVER_ERROR,
"Internal server error",
))
}
}
})
}
}
/// Layer that admits a request when at least one listed `(permission, scope)` pair is
/// granted.
///
/// Wrapping a service yields a [`RequireAnyPermissionService`]. With an empty list no
/// pair can match, so every request is rejected.
#[derive(Clone)]
pub struct RequireAnyPermissionLayer {
    /// Acceptable permissions, each paired with the scope it is looked up in.
    permissions: Vec<(String, PermissionScope)>,
}

impl RequireAnyPermissionLayer {
    /// Builds the layer from borrowed permission names, storing owned copies.
    pub fn new(permissions: Vec<(&str, PermissionScope)>) -> Self {
        let mut owned = Vec::with_capacity(permissions.len());
        owned.extend(
            permissions
                .into_iter()
                .map(|(name, scope)| (String::from(name), scope)),
        );
        Self { permissions: owned }
    }
}

impl<S> Layer<S> for RequireAnyPermissionLayer {
    type Service = RequireAnyPermissionService<S>;

    /// Wraps `inner`, giving the new service its own copy of the permission list.
    fn layer(&self, inner: S) -> Self::Service {
        let permissions = self.permissions.clone();
        RequireAnyPermissionService { inner, permissions }
    }
}
#[derive(Clone)]
pub struct RequireAnyPermissionService<S> {
inner: S,
permissions: Vec<(String, PermissionScope)>,
}
impl<S> Service<Request<Body>> for RequireAnyPermissionService<S>
where
S: Service<Request<Body>, Response = Response> + Send + 'static + Clone,
S::Future: Send + 'static,
{
type Response = Response;
type Error = std::convert::Infallible;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
match self.inner.poll_ready(cx) {
Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
Poll::Ready(Err(_)) => {
error!("PermissionService: Inner service poll_ready returned error");
Poll::Ready(Ok(()))
}
Poll::Pending => Poll::Pending,
}
}
fn call(&mut self, req: Request<Body>) -> Self::Future {
let permissions = self.permissions.clone();
let mut inner = self.inner.clone();
Box::pin(async move {
let auth_context = match req.extensions().get::<AuthContext>() {
Some(ctx) => ctx,
None => {
return Ok(build_error_response(
StatusCode::UNAUTHORIZED,
"No authentication context found",
));
}
};
let has_any_permission = permissions.iter().any(|(permission, scope)| match scope {
PermissionScope::Organization => auth_context
.permissions
.as_ref()
.map(|perms| {
perms
.organization
.as_ref()
.map(|perms| perms.contains(permission))
.unwrap_or(false)
})
.unwrap_or(false),
PermissionScope::Workspace => auth_context
.permissions
.as_ref()
.map(|perms| {
perms
.workspace
.as_ref()
.map(|perms| perms.contains(permission))
.unwrap_or(false)
})
.unwrap_or(false),
});
if !has_any_permission {
let permission_list = permissions
.iter()
.map(|(p, _)| p.as_str())
.collect::<Vec<_>>()
.join(" OR ");
let error_msg = format!("Missing required permission: {permission_list}");
return Ok(build_error_response(StatusCode::FORBIDDEN, &error_msg));
}
match inner.call(req).await {
Ok(response) => Ok(response),
Err(_) => {
error!("PermissionService: Inner service call failed");
Ok(build_error_response(
StatusCode::INTERNAL_SERVER_ERROR,
"Internal server error",
))
}
}
})
}
}