use std::{
fmt::{Debug, Display},
future::{Ready, ready},
};
use actix_web::{
Error as ActixError, FromRequest, HttpMessage, HttpRequest,
body::MessageBody,
dev::{Payload, ServiceRequest, ServiceResponse},
http::Method,
middleware::Next,
};
use async_trait::async_trait;
use glob::Pattern;
use tracing::{debug, info};
use crate::{
security::{AccessForbiddenError, Authorizer, Principal, RbacConfig},
server::ApiErrorV1,
};
use super::{Authenticating, AuthorizationError, Credentials};
/// Actix middleware that authenticates the incoming request and stores the
/// resulting [`AuthenticationContext`] in the request extensions.
///
/// If an [`AuthenticationContext`] is already present in the request
/// extensions, the request is forwarded unchanged and no authentication is
/// performed.
///
/// The authenticator of type `A` is looked up with `req.app_data::<A>()`, so it
/// must be registered as app data of exactly that type.
///
/// # Errors
/// Returns an error (converted through [`ApiErrorV1`]) when:
/// * the credentials cannot be parsed from the request headers,
/// * no authenticator of type `A` is registered as app data,
/// * the authenticator rejects the credentials.
pub async fn authentication_context_provider<A>(
    req: ServiceRequest,
    next: Next<impl MessageBody + 'static>,
) -> Result<ServiceResponse<impl MessageBody>, ActixError>
where
    A: Authenticating<Credentials, Principal, AuthorizationError> + 'static,
{
    // The request was already authenticated (e.g. by another layer); don't do it twice.
    if req.extensions().contains::<AuthenticationContext>() {
        return next.call(req).await;
    }

    let credentials = Credentials::parse_credentials_from_request_headers(req.headers())
        .map_err(ApiErrorV1::from)?;

    let authenticator: &A = req.app_data::<A>().ok_or_else(|| {
        ApiErrorV1::from(AuthorizationError(
            "Could not find any configured authentication mechanism.".to_string(),
        ))
    })?;

    let principal = authenticator
        .authenticate(&credentials)
        .await
        .map_err(ApiErrorV1::from)?;

    let auth_ctx = AuthenticationContext {
        principal,
        method: req.method().to_owned(),
        path: req.path().into(),
        resource: req.match_name().map(|v| v.into()),
        // NOTE(review): this is the `Display` output of `Credentials`, and it is
        // logged at `info` level below. Confirm it only names the authentication
        // scheme and never includes secret material.
        authentication_type: credentials.to_string(),
    };
    info!(
        "Successfully authenticated using {}",
        &auth_ctx.authentication_type
    );
    // What is logged here is the authenticated principal, not the raw credentials.
    debug!("Principal: {}", auth_ctx.principal);
    req.extensions_mut().insert(auth_ctx);
    next.call(req).await
}
/// Outcome of successfully authenticating a request.
///
/// `authentication_context_provider` inserts a value of this type into the
/// request extensions. Handlers can then obtain it as an extractor through the
/// `FromRequest` impl.
#[derive(Clone, Debug)]
pub struct AuthenticationContext {
    /// Principal returned by the configured authenticator.
    principal: Principal,
    /// Describes how the request was authenticated. The middleware fills it
    /// from the credentials' `Display` output.
    authentication_type: String,
    /// HTTP method of the authenticated request.
    method: Method,
    /// Path of the authenticated request.
    path: String,
    /// Resource name used for authorization. `None` when the request could not
    /// be mapped to a resource, in which case authorization is denied.
    resource: Option<String>,
}
impl Display for AuthenticationContext {
    /// Writes every field of the context on one line. A missing resource is
    /// rendered as a placeholder message instead of being omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            principal,
            authentication_type,
            method,
            path,
            resource,
        } = self;
        let resource: &str = resource
            .as_deref()
            .unwrap_or("<Could not map path to a resource>");
        write!(
            f,
            "AuthenticationContext {{ principal: {principal}, method: {method}, path: {path}, resource: {resource}, authentication_type: {authentication_type} }}"
        )
    }
}
impl AuthenticationContext {
    /// Creates a context from borrowed parts, copying all string data into
    /// owned fields.
    pub fn new(
        principal: Principal,
        method: Method,
        path: &str,
        resource: Option<&str>,
        authentication_type: &str,
    ) -> Self {
        let path = path.to_owned();
        let resource = resource.map(str::to_owned);
        let authentication_type = authentication_type.to_owned();
        Self {
            principal,
            authentication_type,
            method,
            path,
            resource,
        }
    }

    /// HTTP method of the authenticated request, as a string.
    pub fn method(&self) -> &str {
        self.method.as_str()
    }

    /// Path of the authenticated request.
    pub fn path(&self) -> &str {
        self.path.as_str()
    }

    /// Description of how the request was authenticated.
    pub fn authentication_type(&self) -> &str {
        self.authentication_type.as_str()
    }

    /// Resource name the request was mapped to, if any.
    pub fn resource(&self) -> Option<&String> {
        Option::as_ref(&self.resource)
    }

    /// The authenticated principal.
    pub fn principal(&self) -> &Principal {
        &self.principal
    }
}
impl FromRequest for AuthenticationContext {
    type Error = ApiErrorV1;
    type Future = Ready<Result<Self, Self::Error>>;

    /// Returns a clone of the context stored in the request extensions by
    /// `authentication_context_provider`. Fails with an authorization error
    /// when no context was stored.
    fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future {
        let stored: Option<Self> = req.extensions().get::<Self>().cloned();
        let outcome: Result<Self, Self::Error> = match stored {
            Some(ctx) => Ok(ctx),
            None => Err(AuthorizationError("Authentication context not found".to_string()).into()),
        };
        ready(outcome)
    }
}
#[async_trait]
impl Authorizer<RbacConfig<String>, (), AccessForbiddenError> for AuthenticationContext {
async fn authorize(
&self,
authz_config: &RbacConfig<String>,
) -> Result<(), AccessForbiddenError> {
let resource = if let Some(resource) = self.resource() {
resource
} else {
return Err(AccessForbiddenError("Resource not found".to_string()));
};
let user_resource_access = self
.principal()
.attributes()
.get("groups")
.and_then(|res| serde_json::from_value::<Vec<String>>(res.clone()).ok())
.ok_or_else(|| {
AccessForbiddenError("Could not find group attribute in claims".to_string())
})?;
let access = authz_config
.roles()
.iter()
.filter_map(|(name, role)| {
role.accessible_resources()
.get(&self.method)
.iter()
.any(|idents| {
idents
.iter()
.filter_map(|ident| Pattern::new(ident).ok())
.any(|pattern| pattern.matches(resource))
})
.then_some(name)
})
.any(|name| user_resource_access.contains(name));
match access {
true => Ok(()),
false => Err(AccessForbiddenError(
"Not authorized to access resource".to_string(),
)),
}
}
}