1use std::{collections::HashMap, sync::Arc};
55
56use serde_json::Value as JsonValue;
57use thiserror::Error;
58
59use crate::graphql::types::{Directive, FieldSelection};
60
61#[derive(Debug, Error)]
63pub enum DirectiveError {
64 #[error("Missing directive argument: {0}")]
66 MissingDirectiveArgument(String),
67
68 #[error("Undefined variable: {0}")]
70 UndefinedVariable(String),
71
72 #[error("Variable type mismatch: {0} should be Boolean")]
74 VariableTypeMismatch(String),
75
76 #[error("Invalid directive argument")]
78 InvalidDirectiveArgument,
79
80 #[error("Custom directive error: {0}")]
82 CustomDirectiveError(String),
83
84 #[error("Unknown directive: @{0}")]
86 UnknownDirective(String),
87
88 #[error("Directive @{0} cannot be used at {1}")]
90 InvalidDirectiveLocation(String, String),
91}
92
93#[derive(Debug, Clone, Default, PartialEq)]
101pub enum DirectiveResult {
102 #[default]
104 Include,
105
106 Skip,
108
109 Transform(JsonValue),
112
113 Error(String),
115}
116
117#[derive(Debug, Clone, Default)]
122pub struct EvaluationContext {
123 pub variables: HashMap<String, JsonValue>,
125
126 pub user_context: HashMap<String, JsonValue>,
129
130 pub field_path: Option<String>,
132
133 pub operation_type: Option<OperationType>,
135}
136
/// Kind of GraphQL operation a selection belongs to.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum OperationType {
    /// A `query` operation.
    Query,
    /// A `mutation` operation.
    Mutation,
    /// A `subscription` operation.
    Subscription,
}
147
148impl EvaluationContext {
149 #[must_use]
151 pub fn new(variables: HashMap<String, JsonValue>) -> Self {
152 Self {
153 variables,
154 ..Default::default()
155 }
156 }
157
158 #[must_use]
160 pub fn with_user_context(mut self, key: impl Into<String>, value: JsonValue) -> Self {
161 self.user_context.insert(key.into(), value);
162 self
163 }
164
165 #[must_use]
167 pub fn with_field_path(mut self, path: impl Into<String>) -> Self {
168 self.field_path = Some(path.into());
169 self
170 }
171
172 #[must_use]
174 pub fn with_operation_type(mut self, op_type: OperationType) -> Self {
175 self.operation_type = Some(op_type);
176 self
177 }
178
179 #[must_use]
181 pub fn get_user_context(&self, key: &str) -> Option<&JsonValue> {
182 self.user_context.get(key)
183 }
184
185 #[must_use]
189 pub fn has_role(&self, role: &str) -> bool {
190 self.user_context
191 .get("roles")
192 .and_then(|v| v.as_array())
193 .is_some_and(|roles| roles.iter().any(|r| r.as_str() == Some(role)))
194 }
195
196 #[must_use]
198 pub fn user_id(&self) -> Option<&str> {
199 self.user_context.get("userId").and_then(|v| v.as_str())
200 }
201}
202
203pub trait DirectiveHandler: Send + Sync {
238 fn name(&self) -> &str;
240
241 fn evaluate(
252 &self,
253 args: &HashMap<String, JsonValue>,
254 context: &EvaluationContext,
255 ) -> Result<DirectiveResult, DirectiveError>;
256
257 fn validate_args(&self, _args: &HashMap<String, JsonValue>) -> Result<(), DirectiveError> {
264 Ok(())
265 }
266}
267
/// Stateless evaluator for the built-in `@skip` and `@include` directives.
///
/// All functionality is exposed as associated functions; the unit value carries no
/// state. It derives the usual marker traits so it can be printed, copied and
/// defaulted like any other public type.
#[derive(Debug, Clone, Copy, Default)]
pub struct DirectiveEvaluator;
300
301impl DirectiveEvaluator {
302 pub fn evaluate_directives(
313 selection: &FieldSelection,
314 variables: &HashMap<String, JsonValue>,
315 ) -> Result<bool, DirectiveError> {
316 if selection.directives.is_empty() {
318 return Ok(true);
319 }
320
321 for directive in &selection.directives {
323 match directive.name.as_str() {
324 "skip" => {
325 if Self::evaluate_skip(directive, variables)? {
327 return Ok(false); }
329 },
330 "include" => {
331 if !Self::evaluate_include(directive, variables)? {
333 return Ok(false); }
335 },
336 _ => {
337 tracing::warn!("Unknown directive @{}", directive.name);
340 },
341 }
342 }
343
344 Ok(true)
346 }
347
348 pub fn evaluate_skip(
352 directive: &Directive,
353 variables: &HashMap<String, JsonValue>,
354 ) -> Result<bool, DirectiveError> {
355 let if_arg = directive
356 .arguments
357 .iter()
358 .find(|a| a.name == "if")
359 .ok_or(DirectiveError::MissingDirectiveArgument("if".to_string()))?;
360
361 Self::resolve_boolean_condition(&if_arg.value_json, variables)
362 }
363
364 pub fn evaluate_include(
368 directive: &Directive,
369 variables: &HashMap<String, JsonValue>,
370 ) -> Result<bool, DirectiveError> {
371 let if_arg = directive
372 .arguments
373 .iter()
374 .find(|a| a.name == "if")
375 .ok_or(DirectiveError::MissingDirectiveArgument("if".to_string()))?;
376
377 Self::resolve_boolean_condition(&if_arg.value_json, variables)
378 }
379
380 fn resolve_boolean_condition(
386 value_json: &str,
387 variables: &HashMap<String, JsonValue>,
388 ) -> Result<bool, DirectiveError> {
389 match serde_json::from_str::<JsonValue>(value_json) {
391 Ok(JsonValue::Bool(b)) => Ok(b),
392 Ok(JsonValue::String(s)) if s.starts_with('$') => {
393 let var_name = &s[1..]; let val = variables
396 .get(var_name)
397 .ok_or_else(|| DirectiveError::UndefinedVariable(var_name.to_string()))?;
398
399 match val {
400 JsonValue::Bool(b) => Ok(*b),
401 _ => Err(DirectiveError::VariableTypeMismatch(var_name.to_string())),
402 }
403 },
404 Ok(_) => Err(DirectiveError::InvalidDirectiveArgument),
405 Err(_) => {
406 if let Some(var_name) = value_json.strip_prefix('$') {
408 let val = variables
409 .get(var_name)
410 .ok_or_else(|| DirectiveError::UndefinedVariable(var_name.to_string()))?;
411
412 match val {
413 JsonValue::Bool(b) => Ok(*b),
414 _ => Err(DirectiveError::VariableTypeMismatch(var_name.to_string())),
415 }
416 } else {
417 Err(DirectiveError::InvalidDirectiveArgument)
418 }
419 },
420 }
421 }
422
423 pub fn filter_selections(
431 selections: &[FieldSelection],
432 variables: &HashMap<String, JsonValue>,
433 ) -> Result<Vec<FieldSelection>, DirectiveError> {
434 let mut result = Vec::new();
435
436 for selection in selections {
437 if Self::evaluate_directives(selection, variables)? {
438 let mut field = selection.clone();
439
440 if !field.nested_fields.is_empty() {
442 field.nested_fields = Self::filter_selections(&field.nested_fields, variables)?;
443 }
444
445 result.push(field);
446 }
447 }
448
449 Ok(result)
450 }
451
452 pub fn parse_directive_args(
456 directive: &Directive,
457 variables: &HashMap<String, JsonValue>,
458 ) -> Result<HashMap<String, JsonValue>, DirectiveError> {
459 let mut args = HashMap::new();
460
461 for arg in &directive.arguments {
462 let value = Self::resolve_argument_value(&arg.value_json, variables)?;
463 args.insert(arg.name.clone(), value);
464 }
465
466 Ok(args)
467 }
468
469 fn resolve_argument_value(
471 value_json: &str,
472 variables: &HashMap<String, JsonValue>,
473 ) -> Result<JsonValue, DirectiveError> {
474 match serde_json::from_str::<JsonValue>(value_json) {
476 Ok(JsonValue::String(s)) if s.starts_with('$') => {
477 let var_name = &s[1..];
479 variables
480 .get(var_name)
481 .cloned()
482 .ok_or_else(|| DirectiveError::UndefinedVariable(var_name.to_string()))
483 },
484 Ok(value) => Ok(value),
485 Err(_) => {
486 if let Some(var_name) = value_json.strip_prefix('$') {
488 variables
489 .get(var_name)
490 .cloned()
491 .ok_or_else(|| DirectiveError::UndefinedVariable(var_name.to_string()))
492 } else {
493 Ok(JsonValue::String(value_json.to_string()))
495 }
496 },
497 }
498 }
499}
500
501#[derive(Clone)]
543pub struct CustomDirectiveEvaluator {
544 handlers: HashMap<String, Arc<dyn DirectiveHandler>>,
546
547 strict_mode: bool,
550}
551
552impl Default for CustomDirectiveEvaluator {
553 fn default() -> Self {
554 Self::new()
555 }
556}
557
558impl CustomDirectiveEvaluator {
559 #[must_use]
561 pub fn new() -> Self {
562 Self {
563 handlers: HashMap::new(),
564 strict_mode: false,
565 }
566 }
567
568 #[must_use]
570 pub fn strict(mut self) -> Self {
571 self.strict_mode = true;
572 self
573 }
574
575 #[must_use]
577 pub fn with_handler(mut self, handler: Arc<dyn DirectiveHandler>) -> Self {
578 let name = handler.name().to_string();
579 self.handlers.insert(name, handler);
580 self
581 }
582
583 #[must_use]
585 pub fn with_handlers(mut self, handlers: Vec<Arc<dyn DirectiveHandler>>) -> Self {
586 for handler in handlers {
587 let name = handler.name().to_string();
588 self.handlers.insert(name, handler);
589 }
590 self
591 }
592
593 #[must_use]
595 pub fn has_handler(&self, name: &str) -> bool {
596 self.handlers.contains_key(name)
597 }
598
599 #[must_use]
601 pub fn get_handler(&self, name: &str) -> Option<&Arc<dyn DirectiveHandler>> {
602 self.handlers.get(name)
603 }
604
605 #[must_use]
607 pub fn handler_names(&self) -> Vec<&str> {
608 self.handlers.keys().map(String::as_str).collect()
609 }
610
611 pub fn evaluate_directives_with_context(
618 &self,
619 selection: &FieldSelection,
620 context: &EvaluationContext,
621 ) -> Result<DirectiveResult, DirectiveError> {
622 if selection.directives.is_empty() {
623 return Ok(DirectiveResult::Include);
624 }
625
626 for directive in &selection.directives {
627 let result = self.evaluate_single_directive(directive, context)?;
628
629 match result {
630 DirectiveResult::Include => {},
631 DirectiveResult::Skip => return Ok(DirectiveResult::Skip),
632 DirectiveResult::Transform(_) | DirectiveResult::Error(_) => return Ok(result),
633 }
634 }
635
636 Ok(DirectiveResult::Include)
637 }
638
639 fn evaluate_single_directive(
641 &self,
642 directive: &Directive,
643 context: &EvaluationContext,
644 ) -> Result<DirectiveResult, DirectiveError> {
645 match directive.name.as_str() {
646 "skip" => {
648 if DirectiveEvaluator::evaluate_skip(directive, &context.variables)? {
649 Ok(DirectiveResult::Skip)
650 } else {
651 Ok(DirectiveResult::Include)
652 }
653 },
654 "include" => {
655 if DirectiveEvaluator::evaluate_include(directive, &context.variables)? {
656 Ok(DirectiveResult::Include)
657 } else {
658 Ok(DirectiveResult::Skip)
659 }
660 },
661 "deprecated" => {
662 Ok(DirectiveResult::Include)
665 },
666 name => {
668 if let Some(handler) = self.handlers.get(name) {
669 let args =
670 DirectiveEvaluator::parse_directive_args(directive, &context.variables)?;
671 handler.evaluate(&args, context)
672 } else if self.strict_mode {
673 Err(DirectiveError::UnknownDirective(name.to_string()))
674 } else {
675 tracing::warn!("Unknown directive @{}, passing through", name);
676 Ok(DirectiveResult::Include)
677 }
678 },
679 }
680 }
681
682 pub fn filter_selections_with_context(
687 &self,
688 selections: &[FieldSelection],
689 context: &EvaluationContext,
690 ) -> Result<Vec<FieldSelection>, DirectiveError> {
691 let mut result = Vec::new();
692
693 for selection in selections {
694 let directive_result = self.evaluate_directives_with_context(selection, context)?;
695
696 match directive_result {
697 DirectiveResult::Include | DirectiveResult::Transform(_) => {
698 let mut field = selection.clone();
699
700 if !field.nested_fields.is_empty() {
702 field.nested_fields =
703 self.filter_selections_with_context(&field.nested_fields, context)?;
704 }
705
706 result.push(field);
707 },
708 DirectiveResult::Skip => {
709 },
711 DirectiveResult::Error(msg) => {
712 return Err(DirectiveError::CustomDirectiveError(msg));
713 },
714 }
715 }
716
717 Ok(result)
718 }
719}
720
721impl std::fmt::Debug for CustomDirectiveEvaluator {
722 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
723 f.debug_struct("CustomDirectiveEvaluator")
724 .field("handlers", &self.handlers.keys().collect::<Vec<_>>())
725 .field("strict_mode", &self.strict_mode)
726 .finish()
727 }
728}
729
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graphql::types::GraphQLArgument;

    /// Builds a leaf field selection carrying `directives`.
    fn make_field(name: &str, directives: Vec<Directive>) -> FieldSelection {
        FieldSelection {
            name: name.to_string(),
            alias: None,
            arguments: vec![],
            nested_fields: vec![],
            directives,
        }
    }

    /// Builds a directive with a single `if` argument whose raw JSON text is `if_value`.
    fn make_directive(name: &str, if_value: &str) -> Directive {
        Directive {
            name: name.to_string(),
            arguments: vec![GraphQLArgument {
                name: "if".to_string(),
                value_type: "boolean".to_string(),
                value_json: if_value.to_string(),
            }],
        }
    }

    #[test]
    fn test_field_without_directives() {
        let field = make_field("email", vec![]);
        let variables = HashMap::new();

        let result = DirectiveEvaluator::evaluate_directives(&field, &variables).unwrap();
        assert!(result);
    }

    #[test]
    fn test_skip_with_true_literal() {
        let field = make_field("email", vec![make_directive("skip", "true")]);
        let variables = HashMap::new();

        let result = DirectiveEvaluator::evaluate_directives(&field, &variables).unwrap();
        // @skip(if: true) excludes the field.
        assert!(!result);
    }

    #[test]
    fn test_skip_with_false_literal() {
        let field = make_field("email", vec![make_directive("skip", "false")]);
        let variables = HashMap::new();

        let result = DirectiveEvaluator::evaluate_directives(&field, &variables).unwrap();
        // @skip(if: false) keeps the field.
        assert!(result);
    }

    #[test]
    fn test_include_with_true_literal() {
        let field = make_field("email", vec![make_directive("include", "true")]);
        let variables = HashMap::new();

        let result = DirectiveEvaluator::evaluate_directives(&field, &variables).unwrap();
        // @include(if: true) keeps the field.
        assert!(result);
    }

    #[test]
    fn test_include_with_false_literal() {
        let field = make_field("email", vec![make_directive("include", "false")]);
        let variables = HashMap::new();

        let result = DirectiveEvaluator::evaluate_directives(&field, &variables).unwrap();
        // @include(if: false) excludes the field.
        assert!(!result);
    }

    #[test]
    fn test_skip_with_variable() {
        let field = make_field("email", vec![make_directive("skip", "\"$skipEmail\"")]);
        let mut variables = HashMap::new();
        variables.insert("skipEmail".to_string(), JsonValue::Bool(true));

        let result = DirectiveEvaluator::evaluate_directives(&field, &variables).unwrap();
        // $skipEmail = true, so the field is skipped.
        assert!(!result);
    }

    #[test]
    fn test_include_with_variable() {
        let field = make_field("email", vec![make_directive("include", "\"$includeEmail\"")]);
        let mut variables = HashMap::new();
        variables.insert("includeEmail".to_string(), JsonValue::Bool(false));

        let result = DirectiveEvaluator::evaluate_directives(&field, &variables).unwrap();
        // $includeEmail = false, so the field is excluded.
        assert!(!result);
    }

    #[test]
    fn test_undefined_variable() {
        let field = make_field("email", vec![make_directive("skip", "\"$undefined\"")]);
        let variables = HashMap::new();

        let result = DirectiveEvaluator::evaluate_directives(&field, &variables);
        assert!(matches!(result, Err(DirectiveError::UndefinedVariable(_))));
    }

    #[test]
    fn test_multiple_directives() {
        // Neither directive excludes the field.
        let directives = vec![
            make_directive("skip", "false"),
            make_directive("include", "true"),
        ];
        let field = make_field("email", directives);
        let variables = HashMap::new();

        let result = DirectiveEvaluator::evaluate_directives(&field, &variables).unwrap();
        assert!(result);
    }

    #[test]
    fn test_variable_type_mismatch() {
        let field = make_field("email", vec![make_directive("skip", "\"$notABool\"")]);
        let mut variables = HashMap::new();
        variables.insert("notABool".to_string(), JsonValue::String("hello".to_string()));

        let result = DirectiveEvaluator::evaluate_directives(&field, &variables);
        assert!(matches!(result, Err(DirectiveError::VariableTypeMismatch(_))));
    }

    #[test]
    fn test_skip_missing_if_argument() {
        // @skip without an `if` argument is rejected rather than silently ignored.
        let field = make_field(
            "email",
            vec![Directive {
                name: "skip".to_string(),
                arguments: vec![],
            }],
        );
        let variables = HashMap::new();

        let result = DirectiveEvaluator::evaluate_directives(&field, &variables);
        assert!(matches!(result, Err(DirectiveError::MissingDirectiveArgument(_))));
    }

    #[test]
    fn test_skip_with_bare_variable() {
        // An unquoted `$name` is not valid JSON but must still resolve as a variable.
        let field = make_field("email", vec![make_directive("skip", "$skipEmail")]);
        let mut variables = HashMap::new();
        variables.insert("skipEmail".to_string(), JsonValue::Bool(true));

        let result = DirectiveEvaluator::evaluate_directives(&field, &variables).unwrap();
        assert!(!result);
    }

    #[test]
    fn test_non_boolean_literal_is_invalid() {
        let field = make_field("email", vec![make_directive("include", "42")]);
        let variables = HashMap::new();

        let result = DirectiveEvaluator::evaluate_directives(&field, &variables);
        assert!(matches!(result, Err(DirectiveError::InvalidDirectiveArgument)));
    }

    #[test]
    fn test_filter_selections() {
        let selections = vec![
            make_field("id", vec![]),
            make_field("email", vec![make_directive("skip", "true")]),
            make_field("name", vec![make_directive("include", "true")]),
        ];

        let variables = HashMap::new();
        let filtered = DirectiveEvaluator::filter_selections(&selections, &variables).unwrap();

        assert_eq!(filtered.len(), 2);
        assert_eq!(filtered[0].name, "id");
        assert_eq!(filtered[1].name, "name");
    }

    #[test]
    fn test_filter_nested_selections() {
        let selections = vec![FieldSelection {
            name: "user".to_string(),
            alias: None,
            arguments: vec![],
            nested_fields: vec![
                make_field("id", vec![]),
                make_field("secret", vec![make_directive("skip", "true")]),
            ],
            directives: vec![],
        }];

        let variables = HashMap::new();
        let filtered = DirectiveEvaluator::filter_selections(&selections, &variables).unwrap();

        assert_eq!(filtered.len(), 1);
        assert_eq!(filtered[0].nested_fields.len(), 1);
        assert_eq!(filtered[0].nested_fields[0].name, "id");
    }

    /// Test handler: includes the field only when the context has the required role.
    struct AuthDirective {
        required_role: String,
    }

    // `name` returns a string literal, but the trait signature ties the `&str` to `&self`.
    #[allow(clippy::unnecessary_literal_bound)]
    impl DirectiveHandler for AuthDirective {
        fn name(&self) -> &str {
            "auth"
        }

        fn evaluate(
            &self,
            args: &HashMap<String, JsonValue>,
            context: &EvaluationContext,
        ) -> Result<DirectiveResult, DirectiveError> {
            // A `role` argument overrides the handler's configured role.
            let required = args.get("role").and_then(|v| v.as_str()).unwrap_or(&self.required_role);

            if context.has_role(required) {
                Ok(DirectiveResult::Include)
            } else {
                Ok(DirectiveResult::Skip)
            }
        }
    }

    /// Test handler: always skips the field.
    struct AlwaysSkipDirective;

    // `name` returns a string literal, but the trait signature ties the `&str` to `&self`.
    #[allow(clippy::unnecessary_literal_bound)]
    impl DirectiveHandler for AlwaysSkipDirective {
        fn name(&self) -> &str {
            "alwaysSkip"
        }

        fn evaluate(
            &self,
            _args: &HashMap<String, JsonValue>,
            _context: &EvaluationContext,
        ) -> Result<DirectiveResult, DirectiveError> {
            Ok(DirectiveResult::Skip)
        }
    }

    /// Test handler: always reports `DirectiveResult::Error`.
    struct ErrorDirective;

    // `name` returns a string literal, but the trait signature ties the `&str` to `&self`.
    #[allow(clippy::unnecessary_literal_bound)]
    impl DirectiveHandler for ErrorDirective {
        fn name(&self) -> &str {
            "error"
        }

        fn evaluate(
            &self,
            _args: &HashMap<String, JsonValue>,
            _context: &EvaluationContext,
        ) -> Result<DirectiveResult, DirectiveError> {
            Ok(DirectiveResult::Error("Test error".to_string()))
        }
    }

    #[test]
    fn test_custom_directive_evaluator_creation() {
        let evaluator = CustomDirectiveEvaluator::new();
        assert!(!evaluator.has_handler("auth"));
        assert!(evaluator.handler_names().is_empty());
    }

    #[test]
    fn test_custom_directive_handler_registration() {
        let auth = Arc::new(AuthDirective {
            required_role: "admin".to_string(),
        });
        let evaluator = CustomDirectiveEvaluator::new().with_handler(auth);

        assert!(evaluator.has_handler("auth"));
        assert!(!evaluator.has_handler("unknown"));
        assert_eq!(evaluator.handler_names(), vec!["auth"]);
    }

    #[test]
    fn test_custom_directive_with_context() {
        let auth = Arc::new(AuthDirective {
            required_role: "admin".to_string(),
        });
        let evaluator = CustomDirectiveEvaluator::new().with_handler(auth);

        // The caller holds the "admin" role.
        let context = EvaluationContext::new(HashMap::new()).with_user_context(
            "roles",
            JsonValue::Array(vec![JsonValue::String("admin".to_string())]),
        );

        let field = FieldSelection {
            name: "sensitiveData".to_string(),
            alias: None,
            arguments: vec![],
            nested_fields: vec![],
            directives: vec![Directive {
                name: "auth".to_string(),
                arguments: vec![GraphQLArgument {
                    name: "role".to_string(),
                    value_type: "String".to_string(),
                    value_json: "\"admin\"".to_string(),
                }],
            }],
        };

        let result = evaluator.evaluate_directives_with_context(&field, &context).unwrap();
        assert_eq!(result, DirectiveResult::Include);
    }

    #[test]
    fn test_custom_directive_denies_without_role() {
        let auth = Arc::new(AuthDirective {
            required_role: "admin".to_string(),
        });
        let evaluator = CustomDirectiveEvaluator::new().with_handler(auth);

        // The caller only holds the "user" role.
        let context = EvaluationContext::new(HashMap::new()).with_user_context(
            "roles",
            JsonValue::Array(vec![JsonValue::String("user".to_string())]),
        );

        let field = FieldSelection {
            name: "sensitiveData".to_string(),
            alias: None,
            arguments: vec![],
            nested_fields: vec![],
            directives: vec![Directive {
                name: "auth".to_string(),
                arguments: vec![GraphQLArgument {
                    name: "role".to_string(),
                    value_type: "String".to_string(),
                    value_json: "\"admin\"".to_string(),
                }],
            }],
        };

        let result = evaluator.evaluate_directives_with_context(&field, &context).unwrap();
        assert_eq!(result, DirectiveResult::Skip);
    }

    #[test]
    fn test_custom_directive_strict_mode_unknown() {
        let evaluator = CustomDirectiveEvaluator::new().strict();

        let context = EvaluationContext::new(HashMap::new());
        let field = make_field(
            "email",
            vec![Directive {
                name: "unknown".to_string(),
                arguments: vec![],
            }],
        );

        let result = evaluator.evaluate_directives_with_context(&field, &context);
        assert!(matches!(result, Err(DirectiveError::UnknownDirective(_))));
    }

    #[test]
    fn test_custom_directive_lenient_mode_unknown() {
        let evaluator = CustomDirectiveEvaluator::new();

        let context = EvaluationContext::new(HashMap::new());
        let field = make_field(
            "email",
            vec![Directive {
                name: "unknown".to_string(),
                arguments: vec![],
            }],
        );

        // Lenient mode passes unknown directives through.
        let result = evaluator.evaluate_directives_with_context(&field, &context).unwrap();
        assert_eq!(result, DirectiveResult::Include);
    }

    #[test]
    fn test_custom_directive_builtin_skip() {
        let evaluator = CustomDirectiveEvaluator::new();
        let context = EvaluationContext::new(HashMap::new());

        let field = make_field("email", vec![make_directive("skip", "true")]);
        let result = evaluator.evaluate_directives_with_context(&field, &context).unwrap();
        assert_eq!(result, DirectiveResult::Skip);
    }

    #[test]
    fn test_custom_directive_builtin_include() {
        let evaluator = CustomDirectiveEvaluator::new();
        let context = EvaluationContext::new(HashMap::new());

        let field = make_field("email", vec![make_directive("include", "false")]);
        let result = evaluator.evaluate_directives_with_context(&field, &context).unwrap();
        assert_eq!(result, DirectiveResult::Skip);
    }

    #[test]
    fn test_filter_selections_with_custom_directive() {
        let always_skip = Arc::new(AlwaysSkipDirective);
        let evaluator = CustomDirectiveEvaluator::new().with_handler(always_skip);

        let selections = vec![
            make_field("id", vec![]),
            make_field(
                "secret",
                vec![Directive {
                    name: "alwaysSkip".to_string(),
                    arguments: vec![],
                }],
            ),
            make_field("name", vec![]),
        ];

        let context = EvaluationContext::new(HashMap::new());
        let filtered = evaluator.filter_selections_with_context(&selections, &context).unwrap();

        assert_eq!(filtered.len(), 2);
        assert_eq!(filtered[0].name, "id");
        assert_eq!(filtered[1].name, "name");
    }

    #[test]
    fn test_filter_selections_with_error_directive() {
        let error = Arc::new(ErrorDirective);
        let evaluator = CustomDirectiveEvaluator::new().with_handler(error);

        let selections = vec![
            make_field("id", vec![]),
            make_field(
                "broken",
                vec![Directive {
                    name: "error".to_string(),
                    arguments: vec![],
                }],
            ),
        ];

        let context = EvaluationContext::new(HashMap::new());
        let result = evaluator.filter_selections_with_context(&selections, &context);

        assert!(matches!(result, Err(DirectiveError::CustomDirectiveError(_))));
    }

    #[test]
    fn test_evaluation_context_has_role() {
        let context = EvaluationContext::new(HashMap::new()).with_user_context(
            "roles",
            JsonValue::Array(vec![
                JsonValue::String("admin".to_string()),
                JsonValue::String("editor".to_string()),
            ]),
        );

        assert!(context.has_role("admin"));
        assert!(context.has_role("editor"));
        assert!(!context.has_role("viewer"));
    }

    #[test]
    fn test_evaluation_context_user_id() {
        let context = EvaluationContext::new(HashMap::new())
            .with_user_context("userId", JsonValue::String("user123".to_string()));

        assert_eq!(context.user_id(), Some("user123"));
    }

    #[test]
    fn test_evaluation_context_field_path() {
        let context = EvaluationContext::new(HashMap::new()).with_field_path("Query.users.email");

        assert_eq!(context.field_path.as_deref(), Some("Query.users.email"));
    }

    #[test]
    fn test_evaluation_context_operation_type() {
        let context =
            EvaluationContext::new(HashMap::new()).with_operation_type(OperationType::Mutation);

        assert_eq!(context.operation_type, Some(OperationType::Mutation));
    }

    #[test]
    fn test_directive_result_default() {
        assert_eq!(DirectiveResult::default(), DirectiveResult::Include);
    }

    #[test]
    fn test_parse_directive_args() {
        let directive = Directive {
            name: "test".to_string(),
            arguments: vec![
                GraphQLArgument {
                    name: "limit".to_string(),
                    value_type: "Int".to_string(),
                    value_json: "10".to_string(),
                },
                GraphQLArgument {
                    name: "name".to_string(),
                    value_type: "String".to_string(),
                    value_json: "\"hello\"".to_string(),
                },
            ],
        };

        let variables = HashMap::new();
        let args = DirectiveEvaluator::parse_directive_args(&directive, &variables).unwrap();

        assert_eq!(args.get("limit"), Some(&JsonValue::Number(10.into())));
        assert_eq!(args.get("name"), Some(&JsonValue::String("hello".to_string())));
    }

    #[test]
    fn test_parse_directive_args_with_variable() {
        let directive = Directive {
            name: "test".to_string(),
            arguments: vec![GraphQLArgument {
                name: "limit".to_string(),
                value_type: "Int".to_string(),
                value_json: "\"$myLimit\"".to_string(),
            }],
        };

        let mut variables = HashMap::new();
        variables.insert("myLimit".to_string(), JsonValue::Number(25.into()));

        let args = DirectiveEvaluator::parse_directive_args(&directive, &variables).unwrap();
        assert_eq!(args.get("limit"), Some(&JsonValue::Number(25.into())));
    }

    #[test]
    fn test_parse_directive_args_bare_non_json_value() {
        // Text that is neither valid JSON nor a `$variable` is kept verbatim as a string.
        let directive = Directive {
            name: "test".to_string(),
            arguments: vec![GraphQLArgument {
                name: "order".to_string(),
                value_type: "Enum".to_string(),
                value_json: "ASC".to_string(),
            }],
        };

        let variables = HashMap::new();
        let args = DirectiveEvaluator::parse_directive_args(&directive, &variables).unwrap();
        assert_eq!(args.get("order"), Some(&JsonValue::String("ASC".to_string())));
    }

    #[test]
    fn test_multiple_handlers() {
        let auth = Arc::new(AuthDirective {
            required_role: "admin".to_string(),
        });
        let skip = Arc::new(AlwaysSkipDirective);

        let evaluator = CustomDirectiveEvaluator::new().with_handlers(vec![auth, skip]);

        assert!(evaluator.has_handler("auth"));
        assert!(evaluator.has_handler("alwaysSkip"));
        assert_eq!(evaluator.handler_names().len(), 2);
    }
}