use std::collections::HashMap;
use serde_json::Value as JsonValue;
use crate::graphql::directive_evaluator::{
DirectiveError, DirectiveHandler, DirectiveResult, EvaluationContext,
};
/// Stateless handler for the `require_permission` directive.
///
/// Permissions are `:`-separated segment strings such as `query:users:read`.
pub struct RequirePermissionDirective;

impl RequirePermissionDirective {
    /// Creates the handler; it carries no state.
    #[must_use]
    pub const fn new() -> Self {
        RequirePermissionDirective
    }

    /// Returns whether one granted permission satisfies `required_permission`.
    ///
    /// A grant satisfies the requirement when:
    /// - it is identical to the requirement;
    /// - it is `*` or `*:*`, both of which satisfy every requirement;
    /// - it ends in a `:*` wildcard, and the requirement either equals the scope in front
    ///   of that wildcard or continues it after a `:` boundary. For example, `query:*`
    ///   satisfies `query` and `query:users:read`, but not `queryx:read`.
    ///
    /// A `*` in any position other than the final segment is compared literally.
    fn permission_matches(user_permission: &str, required_permission: &str) -> bool {
        // Global wildcards and exact grants need no segment inspection.
        if user_permission == "*" || user_permission == "*:*" {
            return true;
        }
        if user_permission == required_permission {
            return true;
        }

        let Some(scope) = user_permission.strip_suffix(":*") else {
            // No trailing wildcard: only the exact match above could have succeeded.
            return false;
        };

        // `scope:*` covers `scope` itself (empty remainder) and anything under `scope:`.
        // Requiring the ':' keeps the match on a whole-segment boundary.
        match required_permission.strip_prefix(scope) {
            Some(rest) => rest.is_empty() || rest.starts_with(':'),
            None => false,
        }
    }

    /// Collects the caller's granted permissions from the user context.
    ///
    /// Reads the `permissions` user-context entry, presumably a JSON array of strings
    /// (TODO confirm against the context producer). Non-string entries are skipped.
    /// A missing or non-array entry yields an empty list.
    fn get_user_permissions(context: &EvaluationContext) -> Vec<String> {
        let granted = context
            .get_user_context("permissions")
            .and_then(|value| value.as_array());
        match granted {
            Some(entries) => entries
                .iter()
                .filter_map(|entry| entry.as_str())
                .map(String::from)
                .collect(),
            None => Vec::new(),
        }
    }

    /// Returns `true` if any grant in `user_permissions` satisfies `required_permission`.
    fn user_has_permission(required_permission: &str, user_permissions: &[String]) -> bool {
        user_permissions
            .iter()
            .map(String::as_str)
            .any(|granted| Self::permission_matches(granted, required_permission))
    }
}
impl Default for RequirePermissionDirective {
fn default() -> Self {
Self::new()
}
}
/// Extracts the directive's `permission` argument and checks that it is usable.
///
/// `evaluate` and `validate_args` both call this helper, so they agree on which
/// arguments are acceptable.
///
/// # Errors
///
/// - [`DirectiveError::MissingDirectiveArgument`] when `permission` is absent.
/// - [`DirectiveError::InvalidDirectiveArgument`] when `permission` is not a string,
///   or is an empty string.
fn required_permission_arg(args: &HashMap<String, JsonValue>) -> Result<&str, DirectiveError> {
    let value = args
        .get("permission")
        .ok_or_else(|| DirectiveError::MissingDirectiveArgument("permission".to_string()))?;
    match value.as_str() {
        Some(permission) if !permission.is_empty() => Ok(permission),
        _ => Err(DirectiveError::InvalidDirectiveArgument),
    }
}

impl DirectiveHandler for RequirePermissionDirective {
    /// Returns the directive name, `require_permission`.
    fn name(&self) -> &'static str {
        "require_permission"
    }

    /// Decides what happens to the field, based on the caller's permissions.
    ///
    /// The caller's permissions are read from the `permissions` user-context entry:
    /// - If any of them satisfies the `permission` argument, returns
    ///   `DirectiveResult::Include`.
    /// - Otherwise, if a `maskValue` argument is present, returns
    ///   `DirectiveResult::Transform` with a copy of it.
    /// - Otherwise, returns `DirectiveResult::Error` naming the missing permission.
    ///
    /// `maskValue` is not re-checked here. `validate_args` restricts it to a string,
    /// number or null.
    ///
    /// # Errors
    ///
    /// - `MissingDirectiveArgument` when `permission` is absent.
    /// - `InvalidDirectiveArgument` when `permission` is not a non-empty string.
    ///   Previously a non-string value was misreported as missing, and an empty string
    ///   was evaluated instead of being rejected.
    fn evaluate(
        &self,
        args: &HashMap<String, JsonValue>,
        context: &EvaluationContext,
    ) -> Result<DirectiveResult, DirectiveError> {
        let required_permission = required_permission_arg(args)?;

        let user_permissions = Self::get_user_permissions(context);
        if Self::user_has_permission(required_permission, &user_permissions) {
            return Ok(DirectiveResult::Include);
        }

        // Unauthorized callers see the configured mask instead of an error, if one is set.
        if let Some(mask_value) = args.get("maskValue") {
            return Ok(DirectiveResult::Transform(mask_value.clone()));
        }

        Ok(DirectiveResult::Error(format!(
            "User lacks required permission: {}",
            required_permission
        )))
    }

    /// Validates the directive's arguments.
    ///
    /// # Errors
    ///
    /// - `MissingDirectiveArgument` when `permission` is absent.
    /// - `InvalidDirectiveArgument` when `permission` is not a non-empty string.
    /// - `InvalidDirectiveArgument` when `maskValue` is present but is not a string,
    ///   number or null.
    fn validate_args(&self, args: &HashMap<String, JsonValue>) -> Result<(), DirectiveError> {
        required_permission_arg(args)?;

        if let Some(mask) = args.get("maskValue") {
            if !(mask.is_string() || mask.is_number() || mask.is_null()) {
                return Err(DirectiveError::InvalidDirectiveArgument);
            }
        }

        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_permission_matches_exact() {
        assert!(RequirePermissionDirective::permission_matches(
            "query:users:read",
            "query:users:read"
        ));
        assert!(!RequirePermissionDirective::permission_matches(
            "query:users:read",
            "query:users:write"
        ));
    }

    #[test]
    fn test_permission_matches_wildcard() {
        assert!(RequirePermissionDirective::permission_matches("*:*", "query:users:read"));
        assert!(RequirePermissionDirective::permission_matches("query:*", "query:users:read"));
        assert!(!RequirePermissionDirective::permission_matches(
            "mutation:*",
            "query:users:read"
        ));
    }

    #[test]
    fn test_permission_matches_wildcard_segment_boundaries() {
        let grants = RequirePermissionDirective::permission_matches;

        // A bare `*` is a global grant, just like `*:*`.
        assert!(grants("*", "mutation:users:delete"));

        // A trailing wildcard covers its own scope and anything beneath it...
        assert!(grants("query:*", "query"));
        assert!(grants("query:users:*", "query:users:read"));

        // ...but only on whole-segment boundaries.
        assert!(!grants("query:*", "queryx:users"));
        assert!(!grants("query:users:*", "query"));

        // A grant without a trailing wildcard never implies narrower permissions.
        assert!(!grants("query:users", "query:users:read"));

        // A wildcard is only honoured in the final segment.
        assert!(!grants("query:*:read", "query:users:read"));
    }

    #[test]
    fn test_user_has_permission() {
        let granted = vec!["query:users:read".to_string(), "mutation:*".to_string()];

        assert!(RequirePermissionDirective::user_has_permission(
            "query:users:read",
            &granted
        ));
        assert!(RequirePermissionDirective::user_has_permission(
            "mutation:users:delete",
            &granted
        ));
        assert!(!RequirePermissionDirective::user_has_permission(
            "query:posts:read",
            &granted
        ));
        assert!(!RequirePermissionDirective::user_has_permission(
            "query:users:read",
            &[]
        ));
    }

    #[test]
    fn test_validate_args() {
        let directive = RequirePermissionDirective::new();
        let mut args: HashMap<String, JsonValue> = HashMap::new();

        assert!(matches!(
            directive.validate_args(&args),
            Err(DirectiveError::MissingDirectiveArgument(_))
        ));

        args.insert("permission".to_string(), JsonValue::from(42));
        assert!(matches!(
            directive.validate_args(&args),
            Err(DirectiveError::InvalidDirectiveArgument)
        ));

        args.insert("permission".to_string(), JsonValue::from(""));
        assert!(matches!(
            directive.validate_args(&args),
            Err(DirectiveError::InvalidDirectiveArgument)
        ));

        args.insert("permission".to_string(), JsonValue::from("query:users:read"));
        assert!(directive.validate_args(&args).is_ok());

        args.insert("maskValue".to_string(), JsonValue::Null);
        assert!(directive.validate_args(&args).is_ok());

        args.insert("maskValue".to_string(), JsonValue::Bool(true));
        assert!(matches!(
            directive.validate_args(&args),
            Err(DirectiveError::InvalidDirectiveArgument)
        ));
    }
}