use anyhow::{anyhow, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
/// Places where a directive may be declared to apply.
///
/// Variant names mirror GraphQL directive locations.
/// NOTE(review): there is no variant for GraphQL's `SCHEMA` or `FIELD` locations — confirm
/// whether that omission is intentional.
///
/// Declaration order is observable: the derived `Hash` uses the variant discriminant, and serde
/// formats that encode variant indices depend on it. Do not reorder variants casually.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum DirectiveLocation {
// Type-system (schema definition) locations.
Object,
FieldDefinition,
Interface,
Union,
Enum,
EnumValue,
InputObject,
InputFieldDefinition,
Scalar,
ArgumentDefinition,
// Executable (operation document) locations.
Query,
Mutation,
Subscription,
FragmentDefinition,
FragmentSpread,
InlineFragment,
VariableDefinition,
}
/// One `name: value` argument supplied to an applied directive.
///
/// Field order determines the serialized field order.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DirectiveArgument {
/// Argument name as written at the application site.
pub name: String,
/// Supplied argument value.
pub value: DirectiveValue,
}
/// A directive argument value, or a field value a directive operates on.
///
/// The enum is `#[serde(untagged)]`: deserialization tries the variants in declaration order and
/// keeps the first one that matches. An integral number that fits in `i64` therefore becomes `Int`
/// before `Float` is tried. The unit variant `Null` is (de)serialized as a unit value (JSON `null`).
/// Reordering the variants changes deserialization results.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum DirectiveValue {
String(String),
Int(i64),
Float(f64),
Boolean(bool),
List(Vec<DirectiveValue>),
Object(HashMap<String, DirectiveValue>),
Null,
}
impl DirectiveValue {
pub fn as_string(&self) -> Option<&str> {
match self {
DirectiveValue::String(s) => Some(s),
_ => None,
}
}
pub fn as_int(&self) -> Option<i64> {
match self {
DirectiveValue::Int(i) => Some(*i),
_ => None,
}
}
pub fn as_float(&self) -> Option<f64> {
match self {
DirectiveValue::Float(f) => Some(*f),
_ => None,
}
}
pub fn as_bool(&self) -> Option<bool> {
match self {
DirectiveValue::Boolean(b) => Some(*b),
_ => None,
}
}
}
/// Schema-level declaration of a directive: the locations it is declared for and the arguments
/// it accepts.
///
/// Field order determines the serialized field order.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DirectiveDefinition {
/// Directive name without the leading `@` (e.g. `"auth"`).
pub name: String,
/// Optional human-readable description.
pub description: Option<String>,
/// Locations the directive is declared for.
pub locations: Vec<DirectiveLocation>,
/// Whether the directive is declared repeatable.
/// NOTE(review): not enforced anywhere in this file.
pub repeatable: bool,
/// Declared arguments.
pub arguments: Vec<DirectiveArgumentDefinition>,
}
/// Declaration of a single argument accepted by a directive.
///
/// Field order determines the serialized field order.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DirectiveArgumentDefinition {
/// Argument name.
pub name: String,
/// GraphQL type written as a string, e.g. `"String"`, `"Int"` or `"String!"` (the forms used by
/// the built-in directives).
pub arg_type: String,
/// Whether the argument is declared as mandatory.
pub required: bool,
/// Default declared for the argument, if any.
/// NOTE(review): nothing in this file substitutes this into an `AppliedDirective`; handlers
/// apply their own fallbacks — confirm that is intended.
pub default_value: Option<DirectiveValue>,
/// Optional human-readable description.
pub description: Option<String>,
}
/// A directive as used at a particular location, together with the arguments supplied there.
///
/// Field order determines the serialized field order.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AppliedDirective {
/// Directive name without the leading `@`.
pub name: String,
/// Supplied arguments in insertion order. Duplicate names are allowed; lookups by name return
/// the first match.
pub arguments: Vec<DirectiveArgument>,
/// Location where the directive is applied.
pub location: DirectiveLocation,
}
impl AppliedDirective {
pub fn new(name: String, location: DirectiveLocation) -> Self {
Self {
name,
arguments: Vec::new(),
location,
}
}
pub fn with_argument(mut self, name: String, value: DirectiveValue) -> Self {
self.arguments.push(DirectiveArgument { name, value });
self
}
pub fn get_argument(&self, name: &str) -> Option<&DirectiveValue> {
self.arguments
.iter()
.find(|arg| arg.name == name)
.map(|arg| &arg.value)
}
}
/// Per-invocation data handed to a [`DirectiveHandler`]. Every field is optional.
#[derive(Debug, Clone)]
pub struct DirectiveContext {
/// Name of the field the directive is applied to, if known.
pub field_name: Option<String>,
/// Name of the enclosing type, if known.
pub type_name: Option<String>,
/// String key/value data about the caller. The built-in auth handlers treat `None` as
/// unauthenticated and read the `"role"` / `"roles"` keys.
pub user_context: Option<HashMap<String, String>>,
/// The value the directive operates on, if any.
pub field_value: Option<DirectiveValue>,
}
impl DirectiveContext {
pub fn new() -> Self {
Self {
field_name: None,
type_name: None,
user_context: None,
field_value: None,
}
}
pub fn with_field(mut self, field_name: String) -> Self {
self.field_name = Some(field_name);
self
}
pub fn with_type(mut self, type_name: String) -> Self {
self.type_name = Some(type_name);
self
}
pub fn with_user_context(mut self, user_context: HashMap<String, String>) -> Self {
self.user_context = Some(user_context);
self
}
pub fn with_value(mut self, value: DirectiveValue) -> Self {
self.field_value = Some(value);
self
}
}
impl Default for DirectiveContext {
fn default() -> Self {
Self::new()
}
}
/// Behavior attached to a named directive.
///
/// The `Send + Sync` supertraits let handlers be shared as `Arc<dyn DirectiveHandler>`.
pub trait DirectiveHandler: Send + Sync {
    /// Applies `directive` to `context` and returns the resulting value.
    ///
    /// # Errors
    /// Implementation-defined. An `Err` means the directive rejected the context (for example,
    /// failed authorization or a violated constraint).
    fn execute(
        &self,
        directive: &AppliedDirective,
        context: &DirectiveContext,
    ) -> Result<DirectiveValue>;

    /// Checks a directive application without executing it.
    ///
    /// The default implementation accepts everything.
    fn validate(&self, _directive: &AppliedDirective) -> Result<()> {
        Ok(())
    }
}
/// Handler for `@auth(requires: String)`.
///
/// Rejects the request when no user context is present. When `requires` is supplied, the
/// user's `"role"` entry must also equal it exactly. On success the field value is passed
/// through, or `Null` if there is none.
pub struct AuthDirectiveHandler;

impl DirectiveHandler for AuthDirectiveHandler {
    /// # Errors
    /// - `"Authentication required"` when `context.user_context` is `None`.
    /// - When `requires` is supplied with a value that is neither a string nor `Null`. The check
    ///   fails closed rather than silently skipping the role check.
    /// - `"Insufficient permissions. Required role: …"` when the user's `"role"` differs.
    fn execute(
        &self,
        directive: &AppliedDirective,
        context: &DirectiveContext,
    ) -> Result<DirectiveValue> {
        let user_context = context
            .user_context
            .as_ref()
            .ok_or_else(|| anyhow!("Authentication required"))?;

        // An explicit `null` is treated like an omitted argument. Any other non-string value is
        // an error: ignoring it would grant access to a field the schema meant to restrict.
        let required_role = match directive.get_argument("requires") {
            None | Some(DirectiveValue::Null) => None,
            Some(value) => Some(value.as_string().ok_or_else(|| {
                anyhow!("@auth directive 'requires' argument must be a String")
            })?),
        };

        if let Some(required_role) = required_role {
            let user_role = user_context.get("role").map(String::as_str);
            if user_role != Some(required_role) {
                return Err(anyhow!(
                    "Insufficient permissions. Required role: {}",
                    required_role
                ));
            }
        }

        Ok(context.field_value.clone().unwrap_or(DirectiveValue::Null))
    }
}
/// Handler for `@hasRole(role: String!)`.
///
/// Treats `user_context["roles"]` as a comma-separated role list and requires it to contain
/// `role`. Whitespace around each listed role is ignored, so `"user, admin"` grants `admin`.
/// On success the field value is passed through, or `Null` if there is none.
pub struct HasRoleDirectiveHandler;

impl DirectiveHandler for HasRoleDirectiveHandler {
    /// # Errors
    /// - When the `role` argument is missing or not a string.
    /// - `"User context not available"` when `context.user_context` is `None`.
    /// - `"User does not have required role: …"` when `role` is not in the list. This includes
    ///   the case where the `"roles"` key is absent.
    fn execute(
        &self,
        directive: &AppliedDirective,
        context: &DirectiveContext,
    ) -> Result<DirectiveValue> {
        let required_role = directive
            .get_argument("role")
            .and_then(DirectiveValue::as_string)
            .ok_or_else(|| anyhow!("@hasRole directive requires 'role' argument"))?;
        let user_context = context
            .user_context
            .as_ref()
            .ok_or_else(|| anyhow!("User context not available"))?;

        // Trim each entry so "user, admin" matches "admin". Checking lazily avoids building a
        // Vec of the roles.
        let has_role = user_context
            .get("roles")
            .into_iter()
            .flat_map(|roles| roles.split(','))
            .any(|role| role.trim() == required_role);
        if !has_role {
            return Err(anyhow!(
                "User does not have required role: {}",
                required_role
            ));
        }

        Ok(context.field_value.clone().unwrap_or(DirectiveValue::Null))
    }
}
/// Handler for `@uppercase`: converts a string field value to upper case.
///
/// Non-string values are returned unchanged; a missing value yields `DirectiveValue::Null`.
pub struct UppercaseDirectiveHandler;

impl DirectiveHandler for UppercaseDirectiveHandler {
    fn execute(
        &self,
        _directive: &AppliedDirective,
        context: &DirectiveContext,
    ) -> Result<DirectiveValue> {
        let transformed = match context.field_value.as_ref() {
            Some(DirectiveValue::String(text)) => DirectiveValue::String(text.to_uppercase()),
            Some(other) => other.clone(),
            None => DirectiveValue::Null,
        };
        Ok(transformed)
    }
}
/// Handler for `@lowercase`: converts a string field value to lower case.
///
/// Non-string values are returned unchanged; a missing value yields `DirectiveValue::Null`.
pub struct LowercaseDirectiveHandler;

impl DirectiveHandler for LowercaseDirectiveHandler {
    fn execute(
        &self,
        _directive: &AppliedDirective,
        context: &DirectiveContext,
    ) -> Result<DirectiveValue> {
        let lowered = context
            .field_value
            .as_ref()
            .map_or(DirectiveValue::Null, |value| match value {
                DirectiveValue::String(text) => DirectiveValue::String(text.to_lowercase()),
                other => other.clone(),
            });
        Ok(lowered)
    }
}
pub struct ConstraintDirectiveHandler;
impl DirectiveHandler for ConstraintDirectiveHandler {
fn execute(
&self,
directive: &AppliedDirective,
context: &DirectiveContext,
) -> Result<DirectiveValue> {
let value = context
.field_value
.as_ref()
.ok_or_else(|| anyhow!("No value to validate"))?;
if let Some(val) = value.as_int() {
if let Some(min) = directive.get_argument("min").and_then(|v| v.as_int()) {
if val < min {
return Err(anyhow!("Value {} is less than minimum {}", val, min));
}
}
if let Some(max) = directive.get_argument("max").and_then(|v| v.as_int()) {
if val > max {
return Err(anyhow!("Value {} is greater than maximum {}", val, max));
}
}
}
if let Some(s) = value.as_string() {
if let Some(min_len) = directive.get_argument("minLength").and_then(|v| v.as_int()) {
if (s.len() as i64) < min_len {
return Err(anyhow!(
"String length {} is less than minimum {}",
s.len(),
min_len
));
}
}
if let Some(max_len) = directive.get_argument("maxLength").and_then(|v| v.as_int()) {
if (s.len() as i64) > max_len {
return Err(anyhow!(
"String length {} is greater than maximum {}",
s.len(),
max_len
));
}
}
}
Ok(value.clone())
}
}
/// Handler for `@cacheControl(maxAge: Int, scope: String)`.
///
/// Each execution records a [`CacheHint`] under the key `"<type_name>:<field_name>"`. A missing
/// name is replaced by `unknown`. The field value is passed through, or `Null` if there is none.
///
/// NOTE(review): `DirectiveRegistry::new` registers its instance as `Arc<dyn DirectiveHandler>`,
/// so hints recorded by that instance cannot be read back through `get_cache_hint`. Confirm
/// whether callers need access to it.
pub struct CacheControlDirectiveHandler {
    // A std (blocking) lock is deliberate. `execute` is synchronous and no guard is held across
    // an `.await`, so the hint is stored before `execute` returns. The earlier `tokio::spawn`
    // approach panicked outside a Tokio runtime and raced with `get_cache_hint`.
    cache_hints: Arc<std::sync::RwLock<HashMap<String, CacheHint>>>,
}

/// Caching metadata recorded for a field by [`CacheControlDirectiveHandler`].
#[derive(Debug, Clone)]
pub struct CacheHint {
    /// `maxAge` argument clamped into `u32` range; defaults to 0.
    pub max_age: u32,
    /// Parsed `scope` argument.
    pub scope: CacheScope,
}

/// Cache scope parsed from the `scope` argument.
///
/// `"PRIVATE"` (case-insensitive) maps to `Private`; anything else, or no argument, maps to
/// `Public`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CacheScope {
    Public,
    Private,
}

impl CacheControlDirectiveHandler {
    /// Creates a handler with no recorded hints.
    pub fn new() -> Self {
        Self {
            cache_hints: Arc::new(std::sync::RwLock::new(HashMap::new())),
        }
    }

    /// Returns the hint recorded for `key` (format `"<type>:<field>"`), if any.
    ///
    /// Still `async` for API compatibility, but it never awaits. A poisoned lock is recovered
    /// because a panicking writer cannot leave a `HashMap::insert` half-done.
    pub async fn get_cache_hint(&self, key: &str) -> Option<CacheHint> {
        let hints = self
            .cache_hints
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        hints.get(key).cloned()
    }
}

impl DirectiveHandler for CacheControlDirectiveHandler {
    fn execute(
        &self,
        directive: &AppliedDirective,
        context: &DirectiveContext,
    ) -> Result<DirectiveValue> {
        let max_age = directive
            .get_argument("maxAge")
            .and_then(DirectiveValue::as_int)
            // Clamp first so the cast is lossless. A negative maxAge becomes 0 instead of wrapping
            // to ~4.29e9 seconds, and oversized values saturate at u32::MAX.
            .map(|age| age.clamp(0, i64::from(u32::MAX)) as u32)
            .unwrap_or(0);
        let scope = match directive
            .get_argument("scope")
            .and_then(DirectiveValue::as_string)
        {
            Some(s) if s.eq_ignore_ascii_case("PRIVATE") => CacheScope::Private,
            _ => CacheScope::Public,
        };
        let key = format!(
            "{}:{}",
            context.type_name.as_deref().unwrap_or("unknown"),
            context.field_name.as_deref().unwrap_or("unknown")
        );

        self.cache_hints
            .write()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .insert(key, CacheHint { max_age, scope });

        Ok(context.field_value.clone().unwrap_or(DirectiveValue::Null))
    }
}

impl Default for CacheControlDirectiveHandler {
    fn default() -> Self {
        Self::new()
    }
}
/// Registry mapping directive names to their [`DirectiveDefinition`] and [`DirectiveHandler`].
///
/// Both maps sit behind `Arc<tokio::sync::RwLock<..>>`, so registration and lookup are `async`.
pub struct DirectiveRegistry {
/// Directive definitions keyed by directive name (without `@`).
definitions: Arc<tokio::sync::RwLock<HashMap<String, DirectiveDefinition>>>,
/// Handlers keyed by directive name. `register_directive` inserts into both maps together.
handlers: Arc<tokio::sync::RwLock<HashMap<String, Arc<dyn DirectiveHandler>>>>,
}
impl DirectiveRegistry {
pub fn new_empty() -> Self {
Self {
definitions: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
handlers: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
}
}
pub async fn new() -> Self {
let registry = Self::new_empty();
let _ = registry.register_builtin_directives().await;
registry
}
async fn register_builtin_directives(&self) -> Result<()> {
self.register_directive(
DirectiveDefinition {
name: "auth".to_string(),
description: Some("Requires authentication and optional role".to_string()),
locations: vec![
DirectiveLocation::Object,
DirectiveLocation::FieldDefinition,
],
repeatable: false,
arguments: vec![DirectiveArgumentDefinition {
name: "requires".to_string(),
arg_type: "String".to_string(),
required: false,
default_value: None,
description: Some("Required role".to_string()),
}],
},
Arc::new(AuthDirectiveHandler),
)
.await?;
self.register_directive(
DirectiveDefinition {
name: "hasRole".to_string(),
description: Some("Requires specific role".to_string()),
locations: vec![DirectiveLocation::FieldDefinition],
repeatable: false,
arguments: vec![DirectiveArgumentDefinition {
name: "role".to_string(),
arg_type: "String!".to_string(),
required: true,
default_value: None,
description: Some("Required role name".to_string()),
}],
},
Arc::new(HasRoleDirectiveHandler),
)
.await?;
self.register_directive(
DirectiveDefinition {
name: "uppercase".to_string(),
description: Some("Converts string to uppercase".to_string()),
locations: vec![DirectiveLocation::FieldDefinition],
repeatable: false,
arguments: vec![],
},
Arc::new(UppercaseDirectiveHandler),
)
.await?;
self.register_directive(
DirectiveDefinition {
name: "lowercase".to_string(),
description: Some("Converts string to lowercase".to_string()),
locations: vec![DirectiveLocation::FieldDefinition],
repeatable: false,
arguments: vec![],
},
Arc::new(LowercaseDirectiveHandler),
)
.await?;
self.register_directive(
DirectiveDefinition {
name: "constraint".to_string(),
description: Some("Validates field value constraints".to_string()),
locations: vec![
DirectiveLocation::FieldDefinition,
DirectiveLocation::ArgumentDefinition,
],
repeatable: false,
arguments: vec![
DirectiveArgumentDefinition {
name: "min".to_string(),
arg_type: "Int".to_string(),
required: false,
default_value: None,
description: Some("Minimum value".to_string()),
},
DirectiveArgumentDefinition {
name: "max".to_string(),
arg_type: "Int".to_string(),
required: false,
default_value: None,
description: Some("Maximum value".to_string()),
},
DirectiveArgumentDefinition {
name: "minLength".to_string(),
arg_type: "Int".to_string(),
required: false,
default_value: None,
description: Some("Minimum string length".to_string()),
},
DirectiveArgumentDefinition {
name: "maxLength".to_string(),
arg_type: "Int".to_string(),
required: false,
default_value: None,
description: Some("Maximum string length".to_string()),
},
],
},
Arc::new(ConstraintDirectiveHandler),
)
.await?;
self.register_directive(
DirectiveDefinition {
name: "cacheControl".to_string(),
description: Some("Cache control hints".to_string()),
locations: vec![
DirectiveLocation::Object,
DirectiveLocation::FieldDefinition,
],
repeatable: false,
arguments: vec![
DirectiveArgumentDefinition {
name: "maxAge".to_string(),
arg_type: "Int".to_string(),
required: false,
default_value: Some(DirectiveValue::Int(0)),
description: Some("Max age in seconds".to_string()),
},
DirectiveArgumentDefinition {
name: "scope".to_string(),
arg_type: "String".to_string(),
required: false,
default_value: Some(DirectiveValue::String("PUBLIC".to_string())),
description: Some("Cache scope (PUBLIC or PRIVATE)".to_string()),
},
],
},
Arc::new(CacheControlDirectiveHandler::new()),
)
.await?;
Ok(())
}
pub async fn register_directive(
&self,
definition: DirectiveDefinition,
handler: Arc<dyn DirectiveHandler>,
) -> Result<()> {
let mut definitions = self.definitions.write().await;
let mut handlers = self.handlers.write().await;
definitions.insert(definition.name.clone(), definition.clone());
handlers.insert(definition.name, handler);
Ok(())
}
pub async fn get_definition(&self, name: &str) -> Option<DirectiveDefinition> {
let definitions = self.definitions.read().await;
definitions.get(name).cloned()
}
pub async fn execute_directive(
&self,
directive: &AppliedDirective,
context: &DirectiveContext,
) -> Result<DirectiveValue> {
let handlers = self.handlers.read().await;
if let Some(handler) = handlers.get(&directive.name) {
handler.execute(directive, context)
} else {
Err(anyhow!("Unknown directive: @{}", directive.name))
}
}
pub async fn validate_directive(&self, directive: &AppliedDirective) -> Result<()> {
let handlers = self.handlers.read().await;
if let Some(handler) = handlers.get(&directive.name) {
handler.validate(directive)
} else {
Err(anyhow!("Unknown directive: @{}", directive.name))
}
}
pub async fn get_directive_names(&self) -> Vec<String> {
let definitions = self.definitions.read().await;
definitions.keys().cloned().collect()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // Handlers are synchronous, so only tests that go through the async registry need a Tokio
    // runtime.

    #[test]
    fn test_directive_value_types() {
        let string_val = DirectiveValue::String("test".to_string());
        assert_eq!(string_val.as_string(), Some("test"));
        let int_val = DirectiveValue::Int(42);
        assert_eq!(int_val.as_int(), Some(42));
        let bool_val = DirectiveValue::Boolean(true);
        assert_eq!(bool_val.as_bool(), Some(true));
        let float_val = DirectiveValue::Float(1.5);
        assert_eq!(float_val.as_float(), Some(1.5));
        assert_eq!(float_val.as_int(), None);
    }

    #[test]
    fn test_applied_directive_creation() {
        let directive =
            AppliedDirective::new("auth".to_string(), DirectiveLocation::FieldDefinition)
                .with_argument(
                    "requires".to_string(),
                    DirectiveValue::String("ADMIN".to_string()),
                );
        assert_eq!(directive.name, "auth");
        assert_eq!(directive.arguments.len(), 1);
        assert_eq!(
            directive
                .get_argument("requires")
                .and_then(|v| v.as_string()),
            Some("ADMIN")
        );
    }

    #[test]
    fn test_directive_context() {
        let context = DirectiveContext::new()
            .with_field("email".to_string())
            .with_type("User".to_string());
        assert_eq!(context.field_name, Some("email".to_string()));
        assert_eq!(context.type_name, Some("User".to_string()));
    }

    #[tokio::test]
    async fn test_directive_registry_creation() {
        let registry = DirectiveRegistry::new().await;
        let names = registry.get_directive_names().await;
        assert!(names.contains(&"auth".to_string()));
        assert!(names.contains(&"uppercase".to_string()));
        assert!(names.contains(&"constraint".to_string()));
    }

    #[tokio::test]
    async fn test_registry_executes_builtin_directive() {
        let registry = DirectiveRegistry::new().await;
        let directive =
            AppliedDirective::new("uppercase".to_string(), DirectiveLocation::FieldDefinition);
        let context =
            DirectiveContext::new().with_value(DirectiveValue::String("hi".to_string()));
        let result = registry
            .execute_directive(&directive, &context)
            .await
            .expect("should succeed");
        assert_eq!(result.as_string(), Some("HI"));
    }

    #[tokio::test]
    async fn test_registry_unknown_directive() {
        let registry = DirectiveRegistry::new_empty();
        let directive =
            AppliedDirective::new("nope".to_string(), DirectiveLocation::FieldDefinition);
        let err = registry
            .execute_directive(&directive, &DirectiveContext::new())
            .await
            .unwrap_err();
        assert!(err.to_string().contains("Unknown directive: @nope"));
        assert!(registry.validate_directive(&directive).await.is_err());
    }

    #[test]
    fn test_uppercase_directive() {
        let handler = UppercaseDirectiveHandler;
        let directive =
            AppliedDirective::new("uppercase".to_string(), DirectiveLocation::FieldDefinition);
        let context =
            DirectiveContext::new().with_value(DirectiveValue::String("hello".to_string()));
        let result = handler
            .execute(&directive, &context)
            .expect("should succeed");
        assert_eq!(result.as_string(), Some("HELLO"));
    }

    #[test]
    fn test_lowercase_directive() {
        let handler = LowercaseDirectiveHandler;
        let directive =
            AppliedDirective::new("lowercase".to_string(), DirectiveLocation::FieldDefinition);
        let context =
            DirectiveContext::new().with_value(DirectiveValue::String("HELLO".to_string()));
        let result = handler
            .execute(&directive, &context)
            .expect("should succeed");
        assert_eq!(result.as_string(), Some("hello"));
    }

    #[test]
    fn test_constraint_directive_min_max() {
        let handler = ConstraintDirectiveHandler;
        let directive =
            AppliedDirective::new("constraint".to_string(), DirectiveLocation::FieldDefinition)
                .with_argument("min".to_string(), DirectiveValue::Int(0))
                .with_argument("max".to_string(), DirectiveValue::Int(100));
        let context = DirectiveContext::new().with_value(DirectiveValue::Int(50));
        assert!(handler.execute(&directive, &context).is_ok());
        let context = DirectiveContext::new().with_value(DirectiveValue::Int(-5));
        assert!(handler.execute(&directive, &context).is_err());
        let context = DirectiveContext::new().with_value(DirectiveValue::Int(150));
        assert!(handler.execute(&directive, &context).is_err());
    }

    #[test]
    fn test_constraint_directive_string_length() {
        let handler = ConstraintDirectiveHandler;
        let directive =
            AppliedDirective::new("constraint".to_string(), DirectiveLocation::FieldDefinition)
                .with_argument("minLength".to_string(), DirectiveValue::Int(2))
                .with_argument("maxLength".to_string(), DirectiveValue::Int(4));
        let ok = DirectiveContext::new().with_value(DirectiveValue::String("abc".to_string()));
        assert!(handler.execute(&directive, &ok).is_ok());
        let short = DirectiveContext::new().with_value(DirectiveValue::String("a".to_string()));
        assert!(handler.execute(&directive, &short).is_err());
        let long =
            DirectiveContext::new().with_value(DirectiveValue::String("abcde".to_string()));
        assert!(handler.execute(&directive, &long).is_err());
    }

    #[test]
    fn test_constraint_directive_requires_value() {
        let handler = ConstraintDirectiveHandler;
        let directive =
            AppliedDirective::new("constraint".to_string(), DirectiveLocation::FieldDefinition);
        let err = handler
            .execute(&directive, &DirectiveContext::new())
            .unwrap_err();
        assert!(err.to_string().contains("No value to validate"));
    }

    #[test]
    fn test_auth_directive_no_context() {
        let handler = AuthDirectiveHandler;
        let directive =
            AppliedDirective::new("auth".to_string(), DirectiveLocation::FieldDefinition);
        let context = DirectiveContext::new();
        let result = handler.execute(&directive, &context);
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Authentication required"));
    }

    #[test]
    fn test_auth_directive_with_role() {
        let handler = AuthDirectiveHandler;
        let directive =
            AppliedDirective::new("auth".to_string(), DirectiveLocation::FieldDefinition)
                .with_argument(
                    "requires".to_string(),
                    DirectiveValue::String("ADMIN".to_string()),
                );
        let mut user_context = HashMap::new();
        user_context.insert("role".to_string(), "USER".to_string());
        let context = DirectiveContext::new().with_user_context(user_context);
        let result = handler.execute(&directive, &context);
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Insufficient permissions"));
    }

    #[test]
    fn test_auth_directive_matching_role_passes_value_through() {
        let handler = AuthDirectiveHandler;
        let directive =
            AppliedDirective::new("auth".to_string(), DirectiveLocation::FieldDefinition)
                .with_argument(
                    "requires".to_string(),
                    DirectiveValue::String("ADMIN".to_string()),
                );
        let mut user_context = HashMap::new();
        user_context.insert("role".to_string(), "ADMIN".to_string());
        let context = DirectiveContext::new()
            .with_user_context(user_context)
            .with_value(DirectiveValue::Int(1));
        let result = handler
            .execute(&directive, &context)
            .expect("should succeed");
        assert_eq!(result.as_int(), Some(1));
    }

    #[test]
    fn test_has_role_directive() {
        let handler = HasRoleDirectiveHandler;
        let directive =
            AppliedDirective::new("hasRole".to_string(), DirectiveLocation::FieldDefinition)
                .with_argument(
                    "role".to_string(),
                    DirectiveValue::String("admin".to_string()),
                );
        let mut user_context = HashMap::new();
        user_context.insert("roles".to_string(), "user,admin,moderator".to_string());
        let context = DirectiveContext::new().with_user_context(user_context);
        let result = handler.execute(&directive, &context);
        assert!(result.is_ok());
    }

    #[test]
    fn test_has_role_directive_missing_role() {
        let handler = HasRoleDirectiveHandler;
        let directive =
            AppliedDirective::new("hasRole".to_string(), DirectiveLocation::FieldDefinition)
                .with_argument(
                    "role".to_string(),
                    DirectiveValue::String("admin".to_string()),
                );
        let mut user_context = HashMap::new();
        user_context.insert("roles".to_string(), "user".to_string());
        let context = DirectiveContext::new().with_user_context(user_context);
        let err = handler.execute(&directive, &context).unwrap_err();
        assert!(err.to_string().contains("User does not have required role"));
    }

    #[test]
    fn test_has_role_directive_requires_argument() {
        let handler = HasRoleDirectiveHandler;
        let directive =
            AppliedDirective::new("hasRole".to_string(), DirectiveLocation::FieldDefinition);
        let err = handler
            .execute(&directive, &DirectiveContext::new())
            .unwrap_err();
        assert!(err.to_string().contains("requires 'role' argument"));
    }
}