// This module only builds test subgraphs: invalid SDL or misconfigured resolvers are reported
// by panicking (see the `expect`/`panic!`/`unimplemented!` calls below) rather than via `Result`.
#![allow(clippy::panic)]
2
3use std::collections::HashMap;
4
5use async_graphql::{
6 ServerError,
7 dynamic::{FieldValue, ResolverContext},
8};
9use cynic_parser::{common::WrappingType, type_system as parser};
10use serde::Deserialize;
11
12use super::{
13 GraphqlSubgraph,
14 entity_resolver::{EntityResolver, EntityResolverContext},
15 resolver::Resolver,
16};
17
/// Field resolvers keyed by `(type name, field name)`.
type ResolverMap = HashMap<(String, String), Box<dyn Resolver>>;
/// Entity resolvers for `_entities`, keyed by entity type name.
type EntityResolverMap = HashMap<String, Box<dyn EntityResolver>>;
20
/// Builder for a test [`GraphqlSubgraph`] backed by an `async_graphql` dynamic schema.
pub struct GraphqlSubgraphBuilder {
    /// SDL of the subgraph; parsed when building and also served from `_service { sdl }`.
    sdl: String,
    /// Name handed through to the built `GraphqlSubgraph`.
    name: String,
    /// Custom field resolvers keyed by `(type name, field name)`.
    field_resolvers: ResolverMap,
    /// Resolvers used by `_entities`, keyed by entity type name.
    entity_resolvers: EntityResolverMap,
}
28
29impl GraphqlSubgraphBuilder {
30 pub(crate) fn new(sdl: String, name: String) -> Self {
31 GraphqlSubgraphBuilder {
32 sdl,
33 name,
34 field_resolvers: Default::default(),
35 entity_resolvers: Default::default(),
36 }
37 }
38
39 pub fn with_name(mut self, name: impl AsRef<str>) -> Self {
41 self.name = name.as_ref().to_string();
42 self
43 }
44
45 pub fn with_resolver(mut self, ty: &str, field: &str, resolver: impl Resolver + 'static) -> Self {
52 self.field_resolvers
53 .insert((ty.into(), field.into()), Box::new(resolver));
54 self
55 }
56
57 pub fn with_entity_resolver(mut self, entity: &str, resolver: impl EntityResolver + 'static) -> Self {
63 self.entity_resolvers.insert(entity.into(), Box::new(resolver));
64 self
65 }
66
67 pub fn build(self) -> GraphqlSubgraph {
69 let Self {
70 sdl,
71 name,
72 mut field_resolvers,
73 entity_resolvers,
74 } = self;
75
76 let schema = cynic_parser::parse_type_system_document(&sdl)
77 .map_err(|e| e.to_report(&sdl))
78 .expect("a valid document");
79
80 let (query_type, ..) = root_types(&schema);
81
82 let mut builder = schema_builder(&schema);
84 builder = builder.register(service_type(&sdl));
85 let entities = find_entities(&schema);
86
87 if !entities.is_empty() {
88 builder = builder.register(any_type());
89 builder = builder.register(entity_type(&entities));
90 }
91
92 let mut entity_resolvers = Some(entity_resolvers);
93 for definition in schema.definitions() {
94 match definition {
95 parser::Definition::Type(def) => {
96 let mut ty = convert_type(def, &mut field_resolvers);
97 if def.name() == query_type
98 && let Some(entity_resolvers) = entity_resolvers.take()
99 {
100 ty = add_federation_fields(ty, &entities, entity_resolvers);
101 }
102 builder = builder.register(ty);
103 }
104 parser::Definition::TypeExtension(_) => {
105 unimplemented!("this is just for tests, extensions aren't supported")
106 }
107 _ => {}
108 }
109 }
110
111 if let Some(entity_resolvers) = entity_resolvers
112 && !entities.is_empty()
113 {
114 builder = builder.register(add_federation_fields(
115 async_graphql::dynamic::Object::new("Query").into(),
116 &entities,
117 entity_resolvers,
118 ));
119 }
120
121 let executable_schema = builder.finish().unwrap();
122
123 GraphqlSubgraph {
124 executable_schema,
125 schema: sdl,
126 name,
127 }
128 }
129}
130
131fn find_entities(schema: &parser::TypeSystemDocument) -> Vec<&str> {
132 schema
133 .definitions()
134 .filter_map(|def| match def {
135 parser::Definition::Type(def) => Some(def),
136 parser::Definition::TypeExtension(def) => Some(def),
137 _ => None,
138 })
139 .filter(|def| def.directives().any(|directive| directive.name() == "key"))
140 .map(|def| def.name())
141 .collect()
142}
143
144fn convert_type(def: parser::TypeDefinition<'_>, resolvers: &mut ResolverMap) -> async_graphql::dynamic::Type {
145 match def {
146 parser::TypeDefinition::Scalar(def) => async_graphql::dynamic::Scalar::new(def.name()).into(),
147 parser::TypeDefinition::Object(def) => convert_object(def, resolvers),
148 parser::TypeDefinition::Interface(def) => convert_iface(def),
149 parser::TypeDefinition::Union(def) => convert_union(def),
150 parser::TypeDefinition::Enum(def) => convert_enum(def),
151 parser::TypeDefinition::InputObject(def) => convert_input_object(def),
152 }
153}
154
155fn convert_object(def: parser::ObjectDefinition<'_>, resolvers: &mut ResolverMap) -> async_graphql::dynamic::Type {
156 use async_graphql::dynamic::*;
157
158 let mut object = Object::new(def.name());
159
160 for name in def.implements_interfaces() {
161 object = object.implement(name);
162 }
163
164 for field_def in def.fields() {
165 let type_ref = convert_type_ref(field_def.ty());
166 let resolver = std::sync::Mutex::new(
167 resolvers
168 .remove(&(def.name().into(), field_def.name().into()))
169 .unwrap_or_else(|| Box::new(default_field_resolver(field_def.name()))),
170 );
171
172 let mut field = Field::new(field_def.name(), type_ref, move |context| {
173 let mut resolver = resolver.lock().expect("mutex to be unpoisoned");
174 FieldFuture::Value(resolver.resolve(context).map(|value| {
175 let value = async_graphql::Value::deserialize(value).unwrap();
176 transform_into_field_value(value)
177 }))
178 });
179
180 for argument in field_def.arguments() {
181 field = field.argument(convert_input_value(argument));
182 }
183
184 object = object.field(field);
185 }
186
187 object.into()
188}
189
190fn transform_into_field_value(mut value: async_graphql::Value) -> FieldValue<'static> {
191 match value {
192 async_graphql::Value::Object(ref mut fields) => {
193 if let Some(async_graphql::Value::String(ty)) = fields.swap_remove("__typename") {
194 FieldValue::from(value).with_type(ty)
195 } else {
196 FieldValue::from(value)
197 }
198 }
199 async_graphql::Value::List(values) => FieldValue::list(values.into_iter().map(transform_into_field_value)),
200 value => FieldValue::from(value),
201 }
202}
203
204fn convert_iface(def: parser::InterfaceDefinition<'_>) -> async_graphql::dynamic::Type {
205 use async_graphql::dynamic::*;
206 let mut interface = Interface::new(def.name());
207
208 for field_def in def.fields() {
209 let type_ref = convert_type_ref(field_def.ty());
210
211 let mut field = InterfaceField::new(field_def.name(), type_ref);
212
213 for argument in field_def.arguments() {
214 field = field.argument(convert_input_value(argument));
215 }
216
217 interface = interface.field(field);
218 }
219
220 interface.into()
221}
222
223fn convert_union(def: parser::UnionDefinition<'_>) -> async_graphql::dynamic::Type {
224 use async_graphql::dynamic::*;
225
226 let mut output = Union::new(def.name());
227
228 for member in def.members() {
229 output = output.possible_type(member.name());
230 }
231
232 output.into()
233}
234
235fn convert_enum(def: parser::EnumDefinition<'_>) -> async_graphql::dynamic::Type {
236 use async_graphql::dynamic::*;
237
238 Enum::new(def.name())
239 .items(def.values().map(|value| EnumItem::new(value.value())))
240 .into()
241}
242
243fn convert_input_object(def: parser::InputObjectDefinition<'_>) -> async_graphql::dynamic::Type {
244 use async_graphql::dynamic::*;
245
246 let mut object = InputObject::new(def.name());
247
248 for field_def in def.fields() {
249 object = object.field(convert_input_value(field_def))
250 }
251
252 object.into()
253}
254
255fn convert_type_ref(ty: parser::Type<'_>) -> async_graphql::dynamic::TypeRef {
256 use async_graphql::dynamic::TypeRef;
257
258 let mut output = TypeRef::named(ty.name());
259
260 for wrapper in ty.wrappers() {
261 match wrapper {
262 WrappingType::NonNull => {
263 output = TypeRef::NonNull(Box::new(output));
264 }
265 WrappingType::List => {
266 output = TypeRef::List(Box::new(output));
267 }
268 }
269 }
270
271 output
272}
273
274fn convert_input_value(value_def: parser::InputValueDefinition<'_>) -> async_graphql::dynamic::InputValue {
275 use async_graphql::dynamic::InputValue;
276
277 let mut value = InputValue::new(value_def.name(), convert_type_ref(value_def.ty()));
278
279 if let Some(default) = value_def.default_value() {
280 value = value.default_value(convert_value(default))
281 }
282
283 value
284}
285
286fn convert_value(value: cynic_parser::ConstValue<'_>) -> async_graphql::Value {
287 match value {
288 cynic_parser::ConstValue::Int(inner) => async_graphql::Value::Number(inner.as_i64().into()),
289 cynic_parser::ConstValue::Float(inner) => {
290 async_graphql::Value::Number(serde_json::Number::from_f64(inner.as_f64()).unwrap())
291 }
292 cynic_parser::ConstValue::String(inner) => async_graphql::Value::String(inner.as_str().into()),
293 cynic_parser::ConstValue::Boolean(inner) => async_graphql::Value::Boolean(inner.as_bool()),
294 cynic_parser::ConstValue::Null(_) => async_graphql::Value::Null,
295 cynic_parser::ConstValue::Enum(inner) => async_graphql::Value::Enum(async_graphql::Name::new(inner.name())),
296 cynic_parser::ConstValue::List(inner) => async_graphql::Value::List(inner.items().map(convert_value).collect()),
297 cynic_parser::ConstValue::Object(inner) => async_graphql::Value::Object(
298 inner
299 .fields()
300 .map(|field| (async_graphql::Name::new(field.name()), convert_value(field.value())))
301 .collect(),
302 ),
303 }
304}
305
306fn schema_builder(schema: &cynic_parser::TypeSystemDocument) -> async_graphql::dynamic::SchemaBuilder {
307 let (query_name, mutation_name, subscription_name) = root_types(schema);
308 async_graphql::dynamic::Schema::build(query_name, mutation_name, subscription_name)
309}
310
311fn root_types(schema: &cynic_parser::TypeSystemDocument) -> (&str, Option<&str>, Option<&str>) {
312 use parser::Definition;
313
314 let mut query_name = "Query";
315 let mut mutation_name = None;
316 let mut subscription_name = None;
317 let mut found_schema_def = false;
318 let mut mutation_present = false;
319 let mut subscription_present = false;
320 for definition in schema.definitions() {
321 if let Definition::Schema(_) = definition {
322 found_schema_def = true;
323 }
324 match definition {
325 Definition::Schema(schema) | Definition::SchemaExtension(schema) => {
326 if let Some(def) = schema.query_type() {
327 query_name = def.named_type();
328 }
329 if let Some(def) = schema.mutation_type() {
330 mutation_name = Some(def.named_type());
331 }
332 if let Some(def) = schema.subscription_type() {
333 subscription_name = Some(def.named_type());
334 }
335 }
336 Definition::Type(ty) | Definition::TypeExtension(ty) if ty.name() == "Mutation" => mutation_present = true,
337 Definition::Type(ty) | Definition::TypeExtension(ty) if ty.name() == "Subscription" => {
338 subscription_present = true
339 }
340 _ => {}
341 }
342 }
343 if !found_schema_def {
344 if mutation_present {
345 mutation_name = Some("Mutation");
346 }
347 if subscription_present {
348 mutation_name = Some("Subscription");
349 }
350 }
351
352 (query_name, mutation_name, subscription_name)
353}
354
355fn default_field_resolver(field_name: &str) -> impl Resolver + 'static {
356 let field_name = async_graphql::Name::new(field_name);
357
358 move |context: ResolverContext<'_>| {
359 if let Some(value) = context.parent_value.as_value() {
360 return match value {
361 async_graphql::Value::Object(map) => {
362 map.get(&field_name).map(|value| value.clone().into_json().unwrap())
363 }
364 _ => None,
365 };
366 }
367 panic!("Unexpected parent value for tests: {:?}", context.parent_value)
368 }
369}
370
371fn service_type(sdl: &str) -> async_graphql::dynamic::Type {
372 use async_graphql::dynamic::*;
373 let mut object = Object::new("_Service");
374
375 let sdl = sdl.to_string();
376
377 object = object.field(Field::new("sdl", TypeRef::named_nn("String"), move |_| {
378 FieldFuture::from_value(Some(async_graphql::Value::String(sdl.clone())))
379 }));
380
381 object.into()
382}
383
384fn entity_type(entities: &[&str]) -> async_graphql::dynamic::Type {
385 use async_graphql::dynamic::*;
386
387 let mut ty = Union::new("_Entity");
388
389 for entity in entities {
390 ty = ty.possible_type(*entity);
391 }
392
393 ty.into()
394}
395
396fn any_type() -> async_graphql::dynamic::Type {
397 use async_graphql::dynamic::*;
398
399 Scalar::new("_Any").into()
400}
401
402fn add_federation_fields(
403 query_ty: async_graphql::dynamic::Type,
404 entities: &[&str],
405 entity_resolvers: EntityResolverMap,
406) -> async_graphql::dynamic::Type {
407 use async_graphql::dynamic::*;
408
409 let async_graphql::dynamic::Type::Object(mut obj) = query_ty else {
410 panic!("this shouldn't happen probably")
411 };
412 obj = obj.field(Field::new("_service", TypeRef::named_nn("_Service"), |_| {
413 FieldFuture::from_value(Some(async_graphql::Value::Object([].into())))
415 }));
416
417 for entity in entity_resolvers.keys() {
418 if !entities.contains(&entity.as_str()) {
419 panic!("Tried to add an resolver for {entity}, but this entity doesnt exist");
420 }
421 }
422
423 if !entities.is_empty() {
424 let resolvers = std::sync::Mutex::new(entity_resolvers);
425
426 let entity_field = Field::new("_entities", TypeRef::named_list_nn("_Entity"), move |context| {
427 let mut resolvers = resolvers.lock().expect("mutex to be unpoisoned");
428 let representations = context
429 .args
430 .get("representations")
431 .expect("_entities needs representations");
432
433 let reprs = representations
434 .deserialize::<Vec<serde_json::Value>>()
435 .expect("representations to be a list of objects");
436
437 let entities = reprs.into_iter().map(|repr| {
438 let context = EntityResolverContext::new(&context, repr);
439
440 let typename = context.typename.clone();
441 let Some(resolver) = resolvers.get_mut(&context.typename) else {
442 context.add_error(ServerError::new(format!("{} has no resolver", context.typename), None));
443 return FieldValue::NULL;
444 };
445
446 let json_value = resolver.resolve(context).unwrap_or_default();
447
448 transform_into_field_value(async_graphql::Value::deserialize(json_value).unwrap()).with_type(typename)
449 });
450
451 FieldFuture::Value(Some(FieldValue::list(entities)))
452 })
453 .argument(InputValue::new("representations", TypeRef::named_nn_list_nn("_Any")));
454
455 obj = obj.field(entity_field);
456 }
457
458 obj.into()
459}