use std::fmt;
use std::fmt::Display;
use std::hash::BuildHasher;
use std::hash::Hash;
use std::hash::Hasher;
use std::ops::Deref;
use std::sync::Arc;
use std::sync::OnceLock;
use apollo_compiler::Name;
use apollo_compiler::Node;
use apollo_compiler::collections::IndexSet;
use apollo_compiler::executable;
use serde::Serialize;
use super::DEFER_DIRECTIVE_NAME;
use super::DEFER_LABEL_ARGUMENT_NAME;
use super::sort_arguments;
/// Orders two GraphQL values whose nested object fields are expected to be sorted already
/// (the caller normalizes them beforehand).
///
/// Values of different kinds are ordered by a fixed rank per variant. Values of the same
/// kind compare structurally:
/// - scalars compare by their contents, with numbers compared by their textual form;
/// - lists compare by length first, then element by element;
/// - objects are delegated to [`compare_sorted_name_value_pairs`].
fn compare_sorted_value(left: &executable::Value, right: &executable::Value) -> std::cmp::Ordering {
    use std::cmp::Ordering;

    use apollo_compiler::executable::Value;

    /// Rank used only when the two values are of different variants.
    fn variant_rank(value: &Value) -> u8 {
        match value {
            Value::Null => 0,
            Value::Enum(_) => 1,
            Value::Variable(_) => 2,
            Value::String(_) => 3,
            Value::Float(_) => 4,
            Value::Int(_) => 5,
            Value::Boolean(_) => 6,
            Value::List(_) => 7,
            Value::Object(_) => 8,
        }
    }

    match (left, right) {
        (Value::Null, Value::Null) => Ordering::Equal,
        (Value::Enum(lhs), Value::Enum(rhs)) => lhs.cmp(rhs),
        (Value::Variable(lhs), Value::Variable(rhs)) => lhs.cmp(rhs),
        (Value::String(lhs), Value::String(rhs)) => lhs.cmp(rhs),
        (Value::Float(lhs), Value::Float(rhs)) => lhs.as_str().cmp(rhs.as_str()),
        (Value::Int(lhs), Value::Int(rhs)) => lhs.as_str().cmp(rhs.as_str()),
        (Value::Boolean(lhs), Value::Boolean(rhs)) => lhs.cmp(rhs),
        (Value::List(lhs), Value::List(rhs)) => {
            let by_length = lhs.len().cmp(&rhs.len());
            if by_length != Ordering::Equal {
                return by_length;
            }
            // Same length: the first differing element decides.
            for (l_item, r_item) in lhs.iter().zip(rhs) {
                let by_item = compare_sorted_value(l_item, r_item);
                if by_item != Ordering::Equal {
                    return by_item;
                }
            }
            Ordering::Equal
        }
        (Value::Object(lhs), Value::Object(rhs)) => compare_sorted_name_value_pairs(
            lhs.iter().map(|(name, _)| name),
            lhs.iter().map(|(_, value)| value),
            rhs.iter().map(|(name, _)| name),
            rhs.iter().map(|(_, value)| value),
        ),
        _ => variant_rank(left).cmp(&variant_rank(right)),
    }
}
/// Orders two sequences of `(name, value)` pairs, each given as parallel name/value iterators.
///
/// Comparison happens in three stages, and each stage runs only if the previous ones tie:
/// 1. the number of pairs;
/// 2. the names, lexicographically;
/// 3. the values, pairwise via [`compare_sorted_value`].
///
/// The pairs are expected to be sorted by name already, so the result does not depend on
/// the order the pairs were originally written in.
fn compare_sorted_name_value_pairs<'doc>(
    left_names: impl ExactSizeIterator<Item = &'doc Name>,
    left_values: impl ExactSizeIterator<Item = &'doc Node<executable::Value>>,
    right_names: impl ExactSizeIterator<Item = &'doc Name>,
    right_values: impl ExactSizeIterator<Item = &'doc Node<executable::Value>>,
) -> std::cmp::Ordering {
    let by_count = left_names.len().cmp(&right_names.len());
    if by_count.is_ne() {
        return by_count;
    }

    let by_names = left_names.cmp(right_names);
    if by_names.is_ne() {
        return by_names;
    }

    for (l_value, r_value) in left_values.zip(right_values) {
        match compare_sorted_value(l_value, r_value) {
            std::cmp::Ordering::Equal => continue,
            decided => return decided,
        }
    }
    std::cmp::Ordering::Equal
}
/// Orders two argument lists that have each been sorted by argument name.
///
/// This is a thin adapter that splits each argument into its name and value and defers to
/// [`compare_sorted_name_value_pairs`].
fn compare_sorted_arguments(
    left: &[Node<executable::Argument>],
    right: &[Node<executable::Argument>],
) -> std::cmp::Ordering {
    let left_names = left.iter().map(|argument| &argument.name);
    let left_values = left.iter().map(|argument| &argument.value);
    let right_names = right.iter().map(|argument| &argument.name);
    let right_values = right.iter().map(|argument| &argument.value);
    compare_sorted_name_value_pairs(left_names, left_values, right_names, right_values)
}
/// Shared empty list that `DirectiveList` hands out by reference when it holds no directives.
static EMPTY_DIRECTIVE_LIST: executable::DirectiveList = executable::DirectiveList(Vec::new());
/// Shared contents of a non-empty `DirectiveList`.
///
/// The directives are stored in the order they were given. Two pieces of derived data are
/// stored alongside them:
/// - a permutation (`sort_order`) that visits the directives in sorted order;
/// - a precomputed hash over that sorted view.
#[derive(Debug, Clone, Serialize)]
struct DirectiveListInner {
    /// Hash of the list length followed by each directive in sorted order; kept up to date
    /// by `rehash`. Not serialized.
    #[serde(skip)]
    hash: u64,
    /// The directives in their stored (given) order; the only field that is serialized.
    #[serde(serialize_with = "crate::utils::serde_bridge::serialize_exe_directive_list")]
    directives: executable::DirectiveList,
    /// Indices into `directives`. Iterating `directives` in this index order yields them
    /// sorted by name, then by (already sorted) arguments. Not serialized.
    #[serde(skip)]
    sort_order: Vec<usize>,
}
impl PartialEq for DirectiveListInner {
    /// Order-independent equality: two lists are equal when their sorted views contain the
    /// same directives, i.e. the same names with the same arguments.
    ///
    /// The cached hashes are compared first as a cheap rejection test. The lengths are then
    /// compared explicitly. The hash can collide, and `zip` stops at the shorter iterator,
    /// so without this check a shorter list whose sorted directives form a prefix of a
    /// longer list's could compare equal to it.
    fn eq(&self, other: &Self) -> bool {
        self.hash == other.hash
            && self.len() == other.len()
            && self
                .iter_sorted()
                .zip(other.iter_sorted())
                .all(|(left, right)| left.name == right.name && left.arguments == right.arguments)
    }
}

/// `eq` compares element-wise, so it is reflexive, symmetric and transitive.
impl Eq for DirectiveListInner {}
impl DirectiveListInner {
fn rehash(&mut self) {
static SHARED_RANDOM: OnceLock<std::hash::RandomState> = OnceLock::new();
let mut state = SHARED_RANDOM.get_or_init(Default::default).build_hasher();
self.len().hash(&mut state);
for d in self.iter_sorted() {
d.hash(&mut state);
}
self.hash = state.finish();
}
fn len(&self) -> usize {
self.directives.len()
}
fn iter_sorted(&self) -> DirectiveIterSorted<'_> {
DirectiveIterSorted {
directives: &self.directives.0,
inner: self.sort_order.iter(),
}
}
}
/// A list of executable directives whose equality and hashing ignore the order in which
/// the directives (and their arguments) were written.
///
/// `Deref` and `Display` expose the directives in the order they were given, with
/// arguments already normalized by `sort_arguments`. The empty list is `inner: None` and
/// needs no allocation. A non-empty list keeps its contents behind an `Arc`, so cloning
/// shares them.
#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize)]
pub(crate) struct DirectiveList {
    /// `None` for the empty list; otherwise the shared, pre-hashed contents.
    inner: Option<Arc<DirectiveListInner>>,
}
impl Deref for DirectiveList {
    type Target = executable::DirectiveList;

    /// Borrows the underlying directives in stored order. An empty list borrows a shared
    /// static empty value.
    fn deref(&self) -> &Self::Target {
        match &self.inner {
            Some(inner) => &inner.directives,
            None => &EMPTY_DIRECTIVE_LIST,
        }
    }
}
impl Hash for DirectiveList {
    /// Feeds the precomputed, order-independent hash into `state`; the empty list
    /// contributes `0`.
    fn hash<H: Hasher>(&self, state: &mut H) {
        let precomputed = match &self.inner {
            Some(inner) => inner.hash,
            None => 0,
        };
        state.write_u64(precomputed);
    }
}
impl Display for DirectiveList {
    /// Renders the directives in stored order; an empty list renders as nothing.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match &self.inner {
            None => Ok(()),
            Some(inner) => Display::fmt(&inner.directives, f),
        }
    }
}
impl From<executable::DirectiveList> for DirectiveList {
    /// Normalizes a plain directive list into an order-independent one.
    ///
    /// 1. An empty input yields the allocation-free empty list.
    /// 2. Each directive's arguments are sorted in place.
    /// 3. A sorted visiting order over the directives is computed.
    /// 4. The hash is precomputed.
    fn from(mut directives: executable::DirectiveList) -> Self {
        if directives.is_empty() {
            return Self::new();
        }

        directives
            .iter_mut()
            .for_each(|directive| sort_arguments(&mut directive.make_mut().arguments));

        // Stable sort: directives that compare equal keep their original relative order.
        let mut sort_order: Vec<usize> = (0..directives.len()).collect();
        sort_order.sort_by(|&a, &b| {
            let (lhs, rhs) = (&directives[a], &directives[b]);
            lhs.name
                .cmp(&rhs.name)
                .then_with(|| compare_sorted_arguments(&lhs.arguments, &rhs.arguments))
        });

        let mut inner = DirectiveListInner {
            hash: 0,
            directives,
            sort_order,
        };
        inner.rehash();

        Self {
            inner: Some(Arc::new(inner)),
        }
    }
}
impl FromIterator<Node<executable::Directive>> for DirectiveList {
    /// Collects directive nodes, then normalizes them through `From<executable::DirectiveList>`.
    fn from_iter<T: IntoIterator<Item = Node<executable::Directive>>>(iter: T) -> Self {
        let collected: executable::DirectiveList = iter.into_iter().collect();
        collected.into()
    }
}
impl FromIterator<executable::Directive> for DirectiveList {
    /// Collects bare directives, then normalizes them through `From<executable::DirectiveList>`.
    fn from_iter<T: IntoIterator<Item = executable::Directive>>(iter: T) -> Self {
        let collected: executable::DirectiveList = iter.into_iter().collect();
        collected.into()
    }
}
impl DirectiveList {
pub(crate) const fn new() -> Self {
Self { inner: None }
}
pub(crate) fn one(directive: impl Into<Node<executable::Directive>>) -> Self {
std::iter::once(directive.into()).collect()
}
#[cfg(test)]
pub(crate) fn parse(input: &str) -> Self {
use apollo_compiler::ast;
let input = format!(
r#"query {{ field
# Directive input:
{input}
#
}}"#
);
let mut parser = apollo_compiler::parser::Parser::new();
let document = parser
.parse_ast(&input, "DirectiveList::parse.graphql")
.unwrap();
let Some(ast::Definition::OperationDefinition(operation)) = document.definitions.first()
else {
unreachable!();
};
let Some(ast::Selection::Field(field)) = operation.selection_set.first() else {
unreachable!();
};
field.directives.clone().into()
}
pub(crate) fn iter(&self) -> impl ExactSizeIterator<Item = &Node<executable::Directive>> {
self.inner
.as_ref()
.map_or(&EMPTY_DIRECTIVE_LIST, |inner| &inner.directives)
.iter()
}
pub(crate) fn remove_one(&mut self, name: &str) -> Option<Node<executable::Directive>> {
let Some(inner) = self.inner.as_mut() else {
return None;
};
let index = inner.directives.iter().position(|dir| dir.name == name)?;
if inner.len() == 1 {
let item = inner.directives[index].clone();
self.inner = None;
return Some(item);
}
let inner = Arc::make_mut(inner);
let sort_index = inner
.sort_order
.iter()
.position(|sorted| *sorted == index)
.expect("index must exist in sort order");
let item = inner.directives.remove(index);
inner.sort_order.remove(sort_index);
for order in &mut inner.sort_order {
if *order > index {
*order -= 1;
}
}
inner.rehash();
Some(item)
}
pub(crate) fn remove_defer(&mut self, defer_labels: &IndexSet<String>) {
let label = self
.get(&DEFER_DIRECTIVE_NAME)
.and_then(|directive| directive.specified_argument_by_name(&DEFER_LABEL_ARGUMENT_NAME))
.and_then(|arg| arg.as_str());
if label.is_some_and(|label| defer_labels.contains(label)) {
self.remove_one(&DEFER_DIRECTIVE_NAME);
}
}
}
/// Yields the directives of a slice in the order given by a permutation of indices
/// (the `sort_order` of a `DirectiveListInner`).
struct DirectiveIterSorted<'a> {
    /// The directives in their stored order.
    directives: &'a [Node<executable::Directive>],
    /// Remaining indices into `directives`, in the order they are to be yielded.
    inner: std::slice::Iter<'a, usize>,
}

impl<'a> Iterator for DirectiveIterSorted<'a> {
    type Item = &'a Node<executable::Directive>;

    fn next(&mut self) -> Option<Self::Item> {
        self.inner.next().map(|index| &self.directives[*index])
    }

    /// Exactly one directive is yielded per remaining index, so the index iterator's exact
    /// hint is ours too. `ExactSizeIterator` requires `size_hint` to be exact; the default
    /// `(0, None)` would violate that contract.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.inner.size_hint()
    }
}

impl ExactSizeIterator for DirectiveIterSorted<'_> {
    fn len(&self) -> usize {
        self.inner.len()
    }
}
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::*;

    /// Lists that are equal regardless of order must also hash equally, so a `HashSet`
    /// treats a reordered list as a duplicate.
    #[test]
    fn consistent_hash() {
        let mut unique = HashSet::new();
        let empty = DirectiveList::new;

        assert!(unique.insert(empty()));
        assert!(!unique.insert(empty()));
        assert!(unique.insert(DirectiveList::parse("@a @b")));
        assert!(!unique.insert(DirectiveList::parse("@b @a")));
    }

    /// Reordering directives, arguments, or input-object fields must not affect equality.
    #[test]
    fn order_independent_equality() {
        assert_eq!(DirectiveList::new(), DirectiveList::new());

        let cases = [
            ("@a @b", "@b @a", "equality should be order independent"),
            (
                "@a(arg1: true, arg2: false) @b(arg2: false, arg1: true)",
                "@b(arg1: true, arg2: false) @a(arg1: true, arg2: false)",
                "arguments should be order independent",
            ),
            (
                "@nested(object: { a: 1, b: 2, c: 3 })",
                "@nested(object: { b: 2, c: 3, a: 1 })",
                "input objects should be order independent",
            ),
            (
                "@nested(object: [true, { a: 1, b: 2, c: { a: 3 } }])",
                "@nested(object: [true, { b: 2, c: { a: 3 }, a: 1 }])",
                "input objects should be order independent",
            ),
        ];
        for (written, reordered, message) in cases {
            assert_eq!(
                DirectiveList::parse(written),
                DirectiveList::parse(reordered),
                "{message}"
            );
        }
    }
}