use crate::{Authority, SecurityContext};
use std::future::Future;
use std::pin::Pin;
/// A single authorization rule that can be evaluated against a
/// [`SecurityContext`].
///
/// Values are usually produced by [`SecurityExpression::parse`] or by the
/// [`Expressions`] helper constructors.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SecurityExpression {
    /// Satisfied when the context reports the named role.
    HasRole(String),
    /// Satisfied when the context holds `Authority::Permission(name)`.
    HasAuthority(String),
    /// Satisfied when the context is authenticated.
    IsAuthenticated,
    /// Satisfied when the context is NOT authenticated.
    IsAnonymous,
    /// Currently evaluated exactly like [`SecurityExpression::IsAuthenticated`].
    IsFullyAuthenticated,
    /// `(target, permission)`, checked as the single authority
    /// `"target:permission"`.
    HasPermission(String, String),
    /// An expression string that could not be recognised. Evaluation is not
    /// implemented for it and always fails closed (returns `false`).
    Custom(String),
}
impl SecurityExpression {
    /// Evaluates this expression against `context`.
    ///
    /// Returns `true` when the context satisfies the rule. `Custom`
    /// expressions are not supported: a warning is logged and `false` is
    /// returned, so unknown rules fail closed.
    pub async fn evaluate(&self, context: &SecurityContext) -> bool {
        match self {
            SecurityExpression::IsAuthenticated => context.is_authenticated().await,
            SecurityExpression::IsAnonymous => !context.is_authenticated().await,
            // NOTE(review): no check stronger than `is_authenticated` is made
            // here, so this behaves exactly like `IsAuthenticated`.
            SecurityExpression::IsFullyAuthenticated => context.is_authenticated().await,
            SecurityExpression::HasRole(role) => {
                let role = crate::Role::from_str(role);
                context.has_role(&role).await
            }
            SecurityExpression::HasAuthority(auth) => {
                let authority = Authority::Permission(auth.clone());
                context.has_authority(&authority).await
            }
            SecurityExpression::HasPermission(target, permission) => {
                // Permissions are modelled as a single `target:permission` authority.
                let auth = Authority::Permission(format!("{}:{}", target, permission));
                context.has_authority(&auth).await
            }
            SecurityExpression::Custom(expr) => {
                tracing::warn!("Custom security expression not implemented: {}", expr);
                false
            }
        }
    }

    /// Parses an expression string into the expressions it mentions.
    ///
    /// Recognised forms are `hasRole('..')` / `hasRole("..")`,
    /// `hasAuthority('..')` / `hasAuthority("..")`, `isAuthenticated()` and
    /// `isAnonymous()`. If none is found, the whole input is returned as a
    /// single [`SecurityExpression::Custom`].
    ///
    /// Matching is purely substring-based:
    /// - only the first `hasRole` / `hasAuthority` call is extracted;
    /// - boolean operators (`and`, `or`, `!`) are NOT interpreted. The caller
    ///   (e.g. `PreAuthorizeOptions::require_all`) decides how the results
    ///   are combined.
    ///
    /// NOTE(review): because of this, a negated call such as
    /// `!isAuthenticated()` still yields `IsAuthenticated`.
    ///
    /// Malformed calls such as `hasRole(')` are ignored instead of panicking.
    pub fn parse(input: &str) -> Vec<Self> {
        let mut expressions = Vec::new();
        if let Some(role) = Self::quoted_argument(input, "hasRole") {
            expressions.push(SecurityExpression::HasRole(role.to_string()));
        }
        if let Some(auth) = Self::quoted_argument(input, "hasAuthority") {
            expressions.push(SecurityExpression::HasAuthority(auth.to_string()));
        }
        if input.contains("isAuthenticated()") {
            expressions.push(SecurityExpression::IsAuthenticated);
        }
        if input.contains("isAnonymous()") {
            expressions.push(SecurityExpression::IsAnonymous);
        }
        if expressions.is_empty() {
            expressions.push(SecurityExpression::Custom(input.to_string()));
        }
        expressions
    }

    /// Returns the quoted argument of the first `function('…')` or
    /// `function("…")` call in `input`.
    ///
    /// Single quotes are tried first. If a single-quoted opener is present
    /// but unterminated, `None` is returned and double quotes are not tried.
    /// The closing delimiter is searched only *after* the opener. This keeps
    /// the slice range ordered, so an input like `hasRole(')` yields `None`
    /// instead of panicking.
    fn quoted_argument<'a>(input: &'a str, function: &str) -> Option<&'a str> {
        for quote in ['\'', '"'] {
            let opener = format!("{function}({quote}");
            if let Some(start) = input.find(&opener) {
                // `opener` is ASCII, so this offset is a valid char boundary.
                let rest = &input[start + opener.len()..];
                let closer = format!("{quote})");
                return rest.find(&closer).map(|len| &rest[..len]);
            }
        }
        None
    }
}
/// Implemented by types that can decide, asynchronously, whether the
/// principal described by a [`SecurityContext`] may proceed.
pub trait PreAuthorize {
    /// Returns a boxed future that resolves to `true` when access is allowed.
    ///
    /// The `'static` bound is the default for `Box<dyn …>` and is only spelled
    /// out here. It means the returned future cannot borrow `self` or
    /// `context`; implementations must move or clone whatever they need into
    /// the future.
    fn check_authorization(
        &self,
        context: &SecurityContext,
    ) -> Pin<Box<dyn Future<Output = bool> + Send + 'static>>;
}
/// A set of security expressions plus the rule for combining their results.
///
/// An empty expression list always authorizes when evaluated.
#[derive(Debug, Clone)]
pub struct PreAuthorizeOptions {
    /// Expressions to evaluate, in insertion order.
    pub expressions: Vec<SecurityExpression>,
    /// `true`: every expression must pass (logical AND).
    /// `false`: at least one expression must pass (logical OR).
    /// `PreAuthorizeOptions::new()` sets this to `true`.
    pub require_all: bool,
}
impl PreAuthorizeOptions {
pub fn new() -> Self {
Self {
expressions: Vec::new(),
require_all: true,
}
}
pub fn add_expression(mut self, expr: SecurityExpression) -> Self {
self.expressions.push(expr);
self
}
pub fn add_expression_string(mut self, expr: impl Into<String>) -> Self {
let parsed = SecurityExpression::parse(&expr.into());
self.expressions.extend(parsed);
self
}
pub fn require_all(mut self, require_all: bool) -> Self {
self.require_all = require_all;
self
}
pub async fn evaluate(&self, context: &SecurityContext) -> bool {
if self.expressions.is_empty() {
return true;
}
if self.require_all {
for expr in &self.expressions {
if !expr.evaluate(context).await {
return false;
}
}
true
} else {
for expr in &self.expressions {
if expr.evaluate(context).await {
return true;
}
}
false
}
}
}
impl Default for PreAuthorizeOptions {
fn default() -> Self {
Self::new()
}
}
/// Parses `expression` and evaluates it against `context` with AND semantics.
///
/// # Errors
///
/// This function currently never returns `Err`. The `Result` return type is
/// kept for API stability.
pub async fn check_pre_authorize(
    context: &SecurityContext,
    expression: &str,
) -> Result<bool, crate::SecurityError> {
    let allowed = PreAuthorizeOptions::new()
        .add_expression_string(expression)
        .evaluate(context)
        .await;
    Ok(allowed)
}
/// Namespace of convenience constructors for [`SecurityExpression`] values.
pub struct Expressions;

impl Expressions {
    /// Builds [`SecurityExpression::HasRole`].
    pub fn has_role(role: impl Into<String>) -> SecurityExpression {
        SecurityExpression::HasRole(role.into())
    }

    /// Builds [`SecurityExpression::HasAuthority`].
    pub fn has_authority(auth: impl Into<String>) -> SecurityExpression {
        SecurityExpression::HasAuthority(auth.into())
    }

    /// Builds [`SecurityExpression::IsAuthenticated`].
    pub fn is_authenticated() -> SecurityExpression {
        SecurityExpression::IsAuthenticated
    }

    /// Builds [`SecurityExpression::IsFullyAuthenticated`].
    ///
    /// Added so that every non-`Custom` variant has a constructor here.
    pub fn is_fully_authenticated() -> SecurityExpression {
        SecurityExpression::IsFullyAuthenticated
    }

    /// Builds [`SecurityExpression::IsAnonymous`].
    pub fn is_anonymous() -> SecurityExpression {
        SecurityExpression::IsAnonymous
    }

    /// Builds [`SecurityExpression::HasPermission`] from a target and a
    /// permission name.
    pub fn has_permission(
        target: impl Into<String>,
        permission: impl Into<String>,
    ) -> SecurityExpression {
        SecurityExpression::HasPermission(target.into(), permission.into())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_security_expression_parse() {
        let parsed = SecurityExpression::parse("hasRole('ADMIN')");
        let [only] = parsed.as_slice() else {
            panic!("expected exactly one expression, got {}", parsed.len());
        };
        let SecurityExpression::HasRole(role) = only else {
            panic!("Expected HasRole");
        };
        assert_eq!(role, "ADMIN");
    }

    #[tokio::test]
    async fn test_pre_authorize_options() {
        // A freshly constructed context is expected to fail this
        // AND-combination of `isAuthenticated` and `hasRole('ADMIN')`.
        let fresh = SecurityContext::new();
        let options = PreAuthorizeOptions::new()
            .add_expression(SecurityExpression::IsAuthenticated)
            .add_expression_string("hasRole('ADMIN')");
        let granted = options.evaluate(&fresh).await;
        assert!(!granted);
    }
}