#[cfg(test)]
mod tests {
    //! Drift linter: checks that every subgraph query shipped by this crate
    //! (`CRATE_QUERIES` and `PUBLIC_QUERIES`) only selects fields that exist in
    //! `specs/subgraph.graphql`, and that sub-fields are selected exactly on the
    //! fields whose type is composite (object/interface), never on leaf fields.

    use std::collections::HashSet;

    use foldhash::{HashMap, HashMapExt};
    use graphql_parser::{
        query::{
            self, Definition as QueryDefinition, OperationDefinition, Selection, SelectionSet,
        },
        schema::{self, Definition as SchemaDefinition, Type, TypeDefinition},
    };

    use crate::subgraph::queries::{CRATE_QUERIES, PUBLIC_QUERIES};

    /// GraphQL meta-field that may be selected on any composite type without
    /// being declared in the schema; it is a leaf (`String!`).
    const TYPENAME: &str = "__typename";

    /// How a schema field's (unwrapped) type must be selected in a query.
    #[derive(Debug, Clone)]
    enum FieldType {
        /// Leaf type (built-in scalar, or a `scalar`/`enum` declared in the SDL):
        /// must NOT carry a sub-selection.
        Scalar,
        /// Composite type with the given name: MUST carry a sub-selection.
        Entity(String),
    }

    /// Classifies a schema field type, looking through `[T]` and `T!` wrappers.
    ///
    /// `declared_leaves` holds the names of the `scalar` and `enum` types declared
    /// in the SDL; they are leaves just like the built-ins recognised by
    /// [`is_scalar`]. Every other named type is treated as composite.
    fn classify(ty: &Type<'_, String>, declared_leaves: &HashSet<String>) -> FieldType {
        match ty {
            Type::NamedType(name) if is_scalar(name) || declared_leaves.contains(name) => {
                FieldType::Scalar
            }
            Type::NamedType(name) => FieldType::Entity(name.clone()),
            Type::ListType(inner) | Type::NonNullType(inner) => classify(inner, declared_leaves),
        }
    }

    /// Returns `true` for the GraphQL spec scalars and for the scalars a subgraph
    /// schema may use without declaring them (`Bytes`, `BigInt`, `BigDecimal`,
    /// `Int8`, `Timestamp`).
    fn is_scalar(name: &str) -> bool {
        matches!(
            name,
            "String"
                | "Int"
                | "Float"
                | "Boolean"
                | "ID"
                | "Bytes"
                | "BigInt"
                | "BigDecimal"
                | "Int8"
                | "Timestamp"
        )
    }

    /// Composite type name -> (field name -> how that field must be selected).
    type SchemaModel = HashMap<String, HashMap<String, FieldType>>;

    /// Parses `specs/subgraph.graphql` (embedded at compile time) into a [`SchemaModel`].
    ///
    /// Object and interface types become model entries. Scalar and enum declarations
    /// are collected first so that fields of those types classify as leaves.
    ///
    /// # Panics
    /// Panics if the embedded SDL does not parse.
    fn build_schema_model() -> SchemaModel {
        let sdl = include_str!("../../specs/subgraph.graphql");
        let doc = schema::parse_schema::<String>(sdl)
            .unwrap_or_else(|e| panic!("failed to parse subgraph.graphql: {e}"));

        // Pass 1: user-declared leaf types. This needs a separate pass because a
        // field may reference an enum/scalar that is declared later in the file.
        let declared_leaves: HashSet<String> = doc
            .definitions
            .iter()
            .filter_map(|def| match def {
                SchemaDefinition::TypeDefinition(TypeDefinition::Scalar(s)) => {
                    Some(s.name.clone())
                }
                SchemaDefinition::TypeDefinition(TypeDefinition::Enum(e)) => Some(e.name.clone()),
                _ => None,
            })
            .collect();

        // Pass 2: composite types and their fields.
        let mut model = HashMap::new();
        for def in &doc.definitions {
            let (type_name, type_fields) = match def {
                SchemaDefinition::TypeDefinition(TypeDefinition::Object(obj)) => {
                    (&obj.name, &obj.fields)
                }
                SchemaDefinition::TypeDefinition(TypeDefinition::Interface(iface)) => {
                    (&iface.name, &iface.fields)
                }
                _ => continue,
            };
            let mut fields = HashMap::new();
            for f in type_fields {
                fields.insert(f.name.clone(), classify(&f.field_type, &declared_leaves));
            }
            model.insert(type_name.clone(), fields);
        }
        model
    }

    /// Lints one query document against `model`, assuming every root field
    /// returns `root_entity`. Returns human-readable problems; empty means clean.
    fn lint_query(query_str: &str, root_entity: &str, model: &SchemaModel) -> Vec<String> {
        let doc = match query::parse_query::<String>(query_str) {
            Ok(d) => d,
            Err(e) => return vec![format!("parse error: {e}")],
        };
        let mut errors = Vec::new();
        for def in &doc.definitions {
            let QueryDefinition::Operation(op) = def else {
                errors.push("fragment definitions are not supported by the drift linter".into());
                continue;
            };
            // All operation kinds are linted the same way; only the selection set matters.
            let selection_set = match op {
                OperationDefinition::Query(q) => &q.selection_set,
                OperationDefinition::SelectionSet(ss) => ss,
                OperationDefinition::Mutation(m) => &m.selection_set,
                OperationDefinition::Subscription(s) => &s.selection_set,
            };
            walk_top_level(selection_set, root_entity, model, &mut errors);
        }
        errors
    }

    /// Walks an operation's root selection set.
    ///
    /// Root field names are not checked against the model. Every root field is
    /// assumed to return `root`, and its sub-selection is validated against that
    /// entity.
    fn walk_top_level(
        ss: &SelectionSet<'_, String>,
        root: &str,
        model: &SchemaModel,
        errors: &mut Vec<String>,
    ) {
        for sel in &ss.items {
            match sel {
                // Meta-field: valid on the root type without a sub-selection.
                Selection::Field(f) if f.name == TYPENAME => {}
                // A field returning an entity must select at least one sub-field.
                Selection::Field(f) if f.selection_set.items.is_empty() => {
                    errors.push(format!(
                        "top-level field `{}` returns entity `{root}` but the query selects \
                         no sub-fields",
                        f.name
                    ));
                }
                Selection::Field(f) => walk_selection_set(&f.selection_set, root, model, errors),
                // Must be reported, as at nested levels; skipping silently would hide
                // every field selected inside the fragment from the linter.
                Selection::FragmentSpread(_) | Selection::InlineFragment(_) => {
                    errors.push(
                        "fragments are not supported by the drift linter (at operation root)"
                            .into(),
                    );
                }
            }
        }
    }

    /// Recursively validates `ss` as a selection on `entity`.
    ///
    /// Reports unknown entities, unknown fields, sub-selections on leaf fields,
    /// composite fields without a sub-selection, and fragments.
    fn walk_selection_set(
        ss: &SelectionSet<'_, String>,
        entity: &str,
        model: &SchemaModel,
        errors: &mut Vec<String>,
    ) {
        let Some(fields) = model.get(entity) else {
            errors.push(format!("unknown entity `{entity}` in schema model"));
            return;
        };
        // `__typename` is an implicit leaf on every composite type.
        let typename_type = FieldType::Scalar;
        for sel in &ss.items {
            let f = match sel {
                Selection::Field(f) => f,
                Selection::FragmentSpread(_) | Selection::InlineFragment(_) => {
                    errors.push(format!(
                        "fragments are not supported by the drift linter (in entity `{entity}`)"
                    ));
                    continue;
                }
            };
            let field_type = match fields.get(&f.name) {
                Some(ft) => ft,
                None if f.name == TYPENAME => &typename_type,
                None => {
                    errors.push(format!(
                        "field `{}` does not exist on entity `{entity}`",
                        f.name
                    ));
                    continue;
                }
            };
            let has_sub_selection = !f.selection_set.items.is_empty();
            match field_type {
                FieldType::Scalar if has_sub_selection => {
                    errors.push(format!(
                        "field `{}` on entity `{entity}` is a scalar but the query \
                         selects sub-fields",
                        f.name
                    ));
                }
                FieldType::Scalar => {}
                FieldType::Entity(sub_entity) if has_sub_selection => {
                    walk_selection_set(&f.selection_set, sub_entity, model, errors);
                }
                // Invalid GraphQL. This is also the typical symptom of a field that
                // changed from a scalar to an entity reference in the schema.
                FieldType::Entity(sub_entity) => {
                    errors.push(format!(
                        "field `{}` on entity `{entity}` has object type `{sub_entity}` but \
                         the query selects no sub-fields",
                        f.name
                    ));
                }
            }
        }
    }

    #[test]
    fn schema_parses_cleanly() {
        let model = build_schema_model();
        assert!(!model.is_empty(), "SDL produced empty schema model");
    }

    #[test]
    fn every_expected_entity_is_present() {
        let model = build_schema_model();
        for (name, _, root_entity) in CRATE_QUERIES.iter().chain(PUBLIC_QUERIES) {
            assert!(
                model.contains_key(*root_entity),
                "query `{name}` claims root entity `{root_entity}`, which is absent from \
                 specs/subgraph.graphql"
            );
        }
    }

    #[test]
    fn every_crate_query_matches_schema() {
        let model = build_schema_model();
        let mut all_errors: Vec<String> = Vec::new();
        for (name, query, root) in CRATE_QUERIES.iter().chain(PUBLIC_QUERIES) {
            let errors = lint_query(query, root, &model);
            for err in errors {
                all_errors.push(format!("[{name}] {err}"));
            }
        }
        assert!(
            all_errors.is_empty(),
            "subgraph query drift detected ({} issues):\n - {}",
            all_errors.len(),
            all_errors.join("\n - ")
        );
    }

    #[test]
    fn linter_catches_missing_field() {
        let model = build_schema_model();
        let bad = "query { totals { tokens definitely_not_a_field } }";
        let errors = lint_query(bad, "Total", &model);
        assert!(
            errors.iter().any(|e| e.contains("definitely_not_a_field")),
            "linter must flag unknown fields; got: {errors:?}"
        );
    }

    #[test]
    fn linter_catches_scalar_with_subselection() {
        let model = build_schema_model();
        let bad = "query { totals { tokens { nope } } }";
        let errors = lint_query(bad, "Total", &model);
        assert!(
            errors.iter().any(|e| e.contains("scalar")),
            "linter must flag sub-selections on scalar fields; got: {errors:?}"
        );
    }

    #[test]
    fn linter_accepts_minimal_valid_query() {
        let model = build_schema_model();
        let good = "query { totals { tokens orders traders } }";
        assert!(lint_query(good, "Total", &model).is_empty());
    }

    #[test]
    fn linter_accepts_typename_meta_field() {
        let model = build_schema_model();
        let good = "query { __typename totals { __typename tokens } }";
        let errors = lint_query(good, "Total", &model);
        assert!(errors.is_empty(), "`__typename` must be accepted; got: {errors:?}");
    }

    #[test]
    fn linter_catches_unknown_entity() {
        let model = build_schema_model();
        let bad = "query { totals { tokens } }";
        let errors = lint_query(bad, "NonExistentEntity", &model);
        assert!(
            errors.iter().any(|e| e.contains("unknown entity")),
            "linter must flag unknown entity; got: {errors:?}"
        );
    }

    #[test]
    fn linter_catches_fragment_spread() {
        let model = build_schema_model();
        let bad = "query { totals { ...TotalFields } } fragment TotalFields on Total { tokens }";
        let errors = lint_query(bad, "Total", &model);
        assert!(
            errors.iter().any(|e| e.contains("fragment")),
            "linter must flag fragments; got: {errors:?}"
        );
    }

    #[test]
    fn linter_catches_top_level_fragment() {
        let model = build_schema_model();
        // No fragment *definition* here, so only the root-level walk can report it.
        let bad = "query { ... on Query { totals { definitely_not_a_field } } }";
        let errors = lint_query(bad, "Total", &model);
        assert!(
            errors.iter().any(|e| e.contains("fragment")),
            "linter must flag fragments at the operation root; got: {errors:?}"
        );
    }

    #[test]
    fn linter_catches_top_level_missing_subselection() {
        let model = build_schema_model();
        let bad = "query { totals }";
        let errors = lint_query(bad, "Total", &model);
        assert!(
            errors.iter().any(|e| e.contains("no sub-fields")),
            "linter must flag entity root fields without sub-fields; got: {errors:?}"
        );
    }

    #[test]
    fn linter_catches_object_field_without_subselection() {
        // Synthetic model so the test does not depend on the real SDL's shape.
        let mut total = HashMap::new();
        total.insert("token".to_owned(), FieldType::Entity("Token".to_owned()));
        let mut token = HashMap::new();
        token.insert("id".to_owned(), FieldType::Scalar);
        let mut model: SchemaModel = HashMap::new();
        model.insert("Total".to_owned(), total);
        model.insert("Token".to_owned(), token);

        let bad = "query { totals { token } }";
        let errors = lint_query(bad, "Total", &model);
        assert!(
            errors.iter().any(|e| e.contains("token") && e.contains("no sub-fields")),
            "linter must flag object fields selected without sub-fields; got: {errors:?}"
        );

        let good = "query { totals { token { id } } }";
        let errors = lint_query(good, "Total", &model);
        assert!(errors.is_empty(), "valid nested selection must pass; got: {errors:?}");
    }

    #[test]
    fn linter_handles_parse_error() {
        let model = build_schema_model();
        let bad = "not a valid graphql query {{{";
        let errors = lint_query(bad, "Total", &model);
        assert!(
            errors.iter().any(|e| e.contains("parse error")),
            "linter must flag parse errors; got: {errors:?}"
        );
    }

    #[test]
    fn linter_mutation_and_subscription_paths() {
        let model = build_schema_model();
        // Mutation and subscription operations must be walked exactly like queries:
        // valid selections pass, invalid ones are reported.
        let cases = [
            (
                "mutation { doSomething { tokens } }",
                "mutation { doSomething { definitely_not_a_field } }",
            ),
            (
                "subscription { totals { tokens } }",
                "subscription { totals { definitely_not_a_field } }",
            ),
        ];
        for (good, bad) in cases {
            let errors = lint_query(good, "Total", &model);
            assert!(errors.is_empty(), "`{good}` must lint cleanly; got: {errors:?}");
            let errors = lint_query(bad, "Total", &model);
            assert!(
                errors.iter().any(|e| e.contains("definitely_not_a_field")),
                "`{bad}` must be flagged; got: {errors:?}"
            );
        }
    }

    #[test]
    fn classify_list_type_entity() {
        let ty = Type::ListType(Box::new(Type::NamedType("Token".to_owned())));
        let ft = classify(&ty, &HashSet::new());
        assert!(matches!(ft, FieldType::Entity(ref name) if name == "Token"));
    }

    #[test]
    fn classify_nonnull_scalar() {
        let ty = Type::NonNullType(Box::new(Type::NamedType("BigInt".to_owned())));
        let ft = classify(&ty, &HashSet::new());
        assert!(matches!(ft, FieldType::Scalar));
    }

    #[test]
    fn classify_declared_enum_is_scalar() {
        let leaves: HashSet<String> = ["OrderStatus".to_owned()].into_iter().collect();
        let ty = Type::NonNullType(Box::new(Type::NamedType("OrderStatus".to_owned())));
        assert!(matches!(classify(&ty, &leaves), FieldType::Scalar));
    }
}