use crate::error::{Error, Result};
use crate::federation::{EntityResolver, FederationConfig, GrpcEntityResolver};
use crate::graphql::{
GraphqlField, GraphqlLiveQuery, GraphqlResponse, GraphqlSchema, GraphqlService, GraphqlType,
};
use crate::grpc_client::{GrpcClient, GrpcClientPool};
use crate::headers::HeaderPropagationConfig;
use crate::live_query::{LiveQueryOperationConfig, LiveQueryStrategy};
use crate::rest_connector::{RestConnectorRegistry, RestFieldType, RestResponseSchema};
use ahash::AHashMap;
use async_graphql::dynamic::{
Enum, EnumItem, Field, FieldFuture, FieldValue, InputObject, InputValue, Object,
ResolverContext, Scalar, Schema as AsyncSchema, Subscription, SubscriptionField,
SubscriptionFieldFuture, TypeRef,
};
use async_graphql::futures_util::StreamExt;
use async_graphql::indexmap::IndexMap;
use async_graphql::{Name, UploadValue, Value as GqlValue};
use base64::engine::general_purpose::STANDARD as BASE64;
use base64::Engine;
use prost::bytes::Buf;
use prost::Message;
use prost_reflect::{
DescriptorPool, DynamicMessage, EnumDescriptor, ExtensionDescriptor, FieldDescriptor, Kind,
MapKey, MessageDescriptor, ReflectMessage, Value,
};
use std::collections::{HashMap, HashSet};
use std::io::Read;
use std::path::Path;
use std::sync::{Arc, Mutex};
use tonic::client::Grpc;
use tonic::codec::{Codec, DecodeBuf, Decoder, EncodeBuf, Encoder};
use tonic::codegen::http;
use tonic::Status;
#[derive(Clone)]
pub struct DynamicSchema {
inner: AsyncSchema,
live_query_configs: Arc<AHashMap<String, LiveQueryOperationConfig>>,
}
impl DynamicSchema {
pub async fn execute(&self, request: async_graphql::Request) -> async_graphql::Response {
self.inner.execute(request).await
}
pub fn executor(&self) -> AsyncSchema {
self.inner.clone()
}
pub fn live_query_configs(&self) -> &AHashMap<String, LiveQueryOperationConfig> {
&self.live_query_configs
}
}
pub struct SchemaBuilder {
descriptor_sets: Vec<Vec<u8>>,
federation: bool,
entity_resolver: Option<std::sync::Arc<dyn EntityResolver>>,
service_allowlist: Option<HashSet<String>>,
query_depth_limit: Option<usize>,
query_complexity_limit: Option<usize>,
introspection_disabled: bool,
header_propagation_config: Option<HeaderPropagationConfig>,
rest_connectors: Option<RestConnectorRegistry>,
}
impl SchemaBuilder {
    /// Creates a builder with no descriptor sets and every optional feature
    /// (federation, limits, introspection toggle, header propagation, REST
    /// connectors, service allowlist) unset or off.
    pub fn new() -> Self {
        Self {
            descriptor_sets: Vec::new(),
            federation: false,
            entity_resolver: None,
            service_allowlist: None,
            query_depth_limit: None,
            query_complexity_limit: None,
            introspection_disabled: false,
            header_propagation_config: None,
            rest_connectors: None,
        }
    }
    /// Replaces all previously added descriptor sets with `bytes`.
    ///
    /// The bytes are only stored here; they are decoded when `build` runs.
    pub fn with_descriptor_set_bytes(mut self, bytes: impl AsRef<[u8]>) -> Self {
        self.descriptor_sets.clear();
        self.descriptor_sets.push(bytes.as_ref().to_vec());
        self
    }
    /// Appends `bytes` as an additional descriptor set, merged into the pool
    /// built from the first set when `build` runs.
    pub fn add_descriptor_set_bytes(mut self, bytes: impl AsRef<[u8]>) -> Self {
        self.descriptor_sets.push(bytes.as_ref().to_vec());
        self
    }
    /// Reads a descriptor set from `path`, replacing any previously added sets.
    ///
    /// # Errors
    /// Returns `Error::Io` if the file cannot be read.
    pub fn with_descriptor_set_file(mut self, path: impl AsRef<Path>) -> Result<Self> {
        let data = std::fs::read(path).map_err(Error::Io)?;
        self.descriptor_sets.clear();
        self.descriptor_sets.push(data);
        Ok(self)
    }
    /// Reads a descriptor set from `path` and appends it to the existing sets.
    ///
    /// # Errors
    /// Returns `Error::Io` if the file cannot be read.
    pub fn add_descriptor_set_file(mut self, path: impl AsRef<Path>) -> Result<Self> {
        let data = std::fs::read(path).map_err(Error::Io)?;
        self.descriptor_sets.push(data);
        Ok(self)
    }
    /// Enables federation: the `graphql.entity` extension is consulted, the
    /// schema gets federation support (`_entities` / `_service`), and
    /// federation directives are applied to generated object types.
    pub fn enable_federation(mut self) -> Self {
        self.federation = true;
        self
    }
    /// Sets the resolver used for `_entities` lookups. Without one, `build`
    /// falls back to a `GrpcEntityResolver` over the client pool.
    pub fn with_entity_resolver(mut self, resolver: std::sync::Arc<dyn EntityResolver>) -> Self {
        self.entity_resolver = Some(resolver);
        self
    }
    /// Restricts schema generation to the given fully-qualified service names;
    /// services not listed are skipped entirely (no fields, no client).
    pub fn with_services<I, S>(mut self, services: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.service_allowlist = Some(services.into_iter().map(Into::into).collect());
        self
    }
    /// Limits query depth on the built schema.
    pub fn with_query_depth_limit(mut self, max_depth: usize) -> Self {
        self.query_depth_limit = Some(max_depth);
        self
    }
    /// Limits query complexity on the built schema.
    pub fn with_query_complexity_limit(mut self, max_complexity: usize) -> Self {
        self.query_complexity_limit = Some(max_complexity);
        self
    }
    /// Disables introspection on the built schema.
    pub fn disable_introspection(mut self) -> Self {
        self.introspection_disabled = true;
        self
    }
    /// Sets the header propagation config handed to every generated gRPC
    /// query/mutation/subscription field.
    pub fn with_header_propagation(mut self, config: HeaderPropagationConfig) -> Self {
        self.header_propagation_config = Some(config);
        self
    }
    /// Registers REST connectors whose endpoints become extra Query and
    /// Mutation fields (plus a `JSON` scalar and any response object types).
    pub fn with_rest_connectors(mut self, registry: RestConnectorRegistry) -> Self {
        self.rest_connectors = Some(registry);
        self
    }
    /// Builds the executable schema.
    ///
    /// Steps:
    /// 1. Merge all descriptor sets into one pool.
    /// 2. For every allow-listed service, lazily connect a client when the
    ///    `graphql.service` option sets a host and `client_pool` has none.
    /// 3. Expose each method carrying a `graphql.schema` option as a Query,
    ///    Mutation or Subscription field, and record live-query settings.
    /// 4. Add REST connector fields.
    /// 5. Apply federation, limits and the introspection toggle.
    ///
    /// # Errors
    /// Returns `Error::Schema` in these cases:
    /// - no descriptor set was supplied;
    /// - a descriptor set fails to decode;
    /// - the `graphql.schema`, `graphql.service` or `graphql.field` extension
    ///   is missing from the pool;
    /// - async-graphql rejects the assembled schema.
    ///
    /// Errors from option decoding, client creation, federation config and
    /// field construction are propagated as-is.
    pub fn build(self, client_pool: &GrpcClientPool) -> Result<DynamicSchema> {
        if self.descriptor_sets.is_empty() {
            return Err(Error::Schema(
                "at least one descriptor set is required".into(),
            ));
        }
        let pool = self.build_merged_descriptor_pool()?;
        // These three option extensions are mandatory.
        let method_ext = pool
            .get_extension_by_name("graphql.schema")
            .ok_or_else(|| Error::Schema("missing graphql.schema extension".into()))?;
        let service_ext = pool
            .get_extension_by_name("graphql.service")
            .ok_or_else(|| Error::Schema("missing graphql.service extension".into()))?;
        let field_ext = pool
            .get_extension_by_name("graphql.field")
            .ok_or_else(|| Error::Schema("missing graphql.field extension".into()))?;
        // `graphql.entity` is only consulted with federation enabled. If it is
        // absent, an empty FederationConfig is used.
        let entity_ext = if self.federation {
            pool.get_extension_by_name("graphql.entity")
        } else {
            None
        };
        let federation_config = if let Some(entity_ext) = entity_ext.as_ref() {
            FederationConfig::from_descriptor_pool(&pool, entity_ext)?
        } else {
            FederationConfig::new()
        };
        // Optional: when missing, no live-query configs are recorded.
        let live_query_ext = pool.get_extension_by_name("graphql.live_query");
        let mut live_query_configs = AHashMap::new();
        let mut registry = TypeRegistry::default();
        // Root types are created lazily on their first field, so an empty
        // Mutation/Subscription root is never registered.
        let mut query_root: Option<Object> = None;
        let mut mutation_root: Option<Object> = None;
        let mut subscription_root: Option<Subscription> = None;
        for service in pool.services() {
            if let Some(allowlist) = self.service_allowlist.as_ref() {
                if !allowlist.contains(service.full_name()) {
                    continue;
                }
            }
            let service_options =
                decode_extension::<GraphqlService>(&service.options(), &service_ext)?
                    .unwrap_or_default();
            // Only create a client when the service declares a host and the
            // pool does not already hold one under this service name.
            if !service_options.host.is_empty() && client_pool.get(service.full_name()).is_none() {
                let client = GrpcClient::connect_lazy(
                    service_options.host.clone(),
                    service_options.insecure,
                )?;
                client_pool.add(service.full_name(), client);
            }
            for method in service.methods() {
                // Methods without a `graphql.schema` option are not exposed.
                let Some(schema_opts) =
                    decode_extension::<GraphqlSchema>(&method.options(), &method_ext)?
                else {
                    continue;
                };
                // Unrecognized type values are treated as Query.
                let graphql_type =
                    GraphqlType::try_from(schema_opts.r#type).unwrap_or(GraphqlType::Query);
                let field_name = schema_opts.name.clone();
                // An empty name means the method is not exposed.
                if field_name.is_empty() {
                    continue;
                }
                match graphql_type {
                    // Resolver-typed methods are exposed as Query fields too.
                    GraphqlType::Query | GraphqlType::Resolver => {
                        let field = build_field(
                            field_name.clone(),
                            &service,
                            &method,
                            &schema_opts,
                            field_ext.clone(),
                            &mut registry,
                            client_pool.clone(),
                            self.header_propagation_config.clone(),
                        )?;
                        let mut query = query_root.take().unwrap_or_else(|| Object::new("Query"));
                        query = query.field(field);
                        query_root = Some(query);
                        // Live queries are only configurable on Query/Resolver
                        // fields and are keyed by the GraphQL field name.
                        if let Some(ext) = &live_query_ext {
                            if let Some(live_opts) =
                                decode_extension::<GraphqlLiveQuery>(&method.options(), ext)?
                            {
                                if live_opts.enabled {
                                    // Map the option-level strategy enum onto
                                    // the runtime enum. Unknown values fall back
                                    // to Invalidation.
                                    let strategy = crate::graphql::LiveQueryStrategy::try_from(
                                        live_opts.strategy,
                                    )
                                    .ok()
                                    .map(|s| match s {
                                        crate::graphql::LiveQueryStrategy::Invalidation => {
                                            LiveQueryStrategy::Invalidation
                                        }
                                        crate::graphql::LiveQueryStrategy::Polling => {
                                            LiveQueryStrategy::Polling
                                        }
                                        crate::graphql::LiveQueryStrategy::HashDiff => {
                                            LiveQueryStrategy::HashDiff
                                        }
                                    })
                                    .unwrap_or(LiveQueryStrategy::Invalidation);
                                    let config = LiveQueryOperationConfig {
                                        operation_name: field_name.clone(),
                                        enabled: true,
                                        throttle_ms: live_opts.throttle_ms,
                                        triggers: live_opts.triggers,
                                        max_connections: live_opts.max_connections,
                                        ttl_seconds: live_opts.ttl_seconds,
                                        strategy,
                                        poll_interval_ms: live_opts.poll_interval_ms,
                                        depends_on: live_opts.depends_on,
                                    };
                                    live_query_configs.insert(field_name, config);
                                }
                            }
                        }
                    }
                    GraphqlType::Mutation => {
                        let field = build_field(
                            field_name,
                            &service,
                            &method,
                            &schema_opts,
                            field_ext.clone(),
                            &mut registry,
                            client_pool.clone(),
                            self.header_propagation_config.clone(),
                        )?;
                        let mut mutation = mutation_root
                            .take()
                            .unwrap_or_else(|| Object::new("Mutation"));
                        mutation = mutation.field(field);
                        mutation_root = Some(mutation);
                    }
                    GraphqlType::Subscription => {
                        let field = build_subscription_field(
                            field_name,
                            &service,
                            &method,
                            &schema_opts,
                            field_ext.clone(),
                            &mut registry,
                            client_pool.clone(),
                            self.header_propagation_config.clone(),
                        )?;
                        let mut subscription = subscription_root
                            .take()
                            .unwrap_or_else(|| Subscription::new("Subscription"));
                        subscription = subscription.field(field);
                        subscription_root = Some(subscription);
                    }
                }
            }
        }
        // Response schemas of REST fields, collected so that their object
        // types can be registered once the schema builder exists.
        let mut rest_schemas = Vec::new();
        if let Some(ref rest_registry) = self.rest_connectors {
            let (rest_queries, rest_mutations) = rest_registry.build_graphql_fields();
            let query_count = rest_queries.len();
            let mutation_count = rest_mutations.len();
            for rest_field in rest_queries {
                if let Some(schema) = &rest_field.response_schema {
                    rest_schemas.push(schema.clone());
                }
                let field = build_rest_field(&rest_field);
                let mut query = query_root.take().unwrap_or_else(|| Object::new("Query"));
                query = query.field(field);
                query_root = Some(query);
            }
            for rest_field in rest_mutations {
                if let Some(schema) = &rest_field.response_schema {
                    rest_schemas.push(schema.clone());
                }
                let field = build_rest_field(&rest_field);
                let mut mutation = mutation_root
                    .take()
                    .unwrap_or_else(|| Object::new("Mutation"));
                mutation = mutation.field(field);
                mutation_root = Some(mutation);
            }
            tracing::info!(
                "Added {} REST query fields and {} REST mutation fields to schema",
                query_count,
                mutation_count
            );
        }
        // A Query root is always needed; use a placeholder when no source
        // contributed any query field.
        let query_root = query_root.unwrap_or_else(placeholder_query_root);
        let mut schema_builder = AsyncSchema::build(
            query_root.type_name(),
            mutation_root.as_ref().map(Object::type_name),
            subscription_root.as_ref().map(Subscription::type_name),
        );
        // Uploads back `bytes` input fields (see `type_for_field`).
        schema_builder = schema_builder.enable_uploading();
        if self.rest_connectors.is_some() {
            // REST fields without a response schema return this scalar.
            let json_scalar =
                Scalar::new("JSON").description("Arbitrary JSON value returned from REST APIs");
            schema_builder = schema_builder.register(json_scalar);
            // Several endpoints may share a response type; register each name once.
            let mut registered_types = std::collections::HashSet::new();
            for schema in rest_schemas {
                if !registered_types.contains(&schema.type_name) {
                    let obj = build_rest_response_type(&schema);
                    schema_builder = schema_builder.register(obj);
                    registered_types.insert(schema.type_name.clone());
                }
            }
        }
        if self.federation {
            schema_builder = schema_builder.enable_federation();
            // The `_entities` resolver is installed only when the federation
            // config reports itself as enabled.
            if federation_config.is_enabled() {
                let config = federation_config.clone();
                let resolver = self.entity_resolver.clone().unwrap_or_else(|| {
                    std::sync::Arc::new(GrpcEntityResolver::new(client_pool.clone()))
                });
                schema_builder = schema_builder.entity_resolver(move |ctx: ResolverContext<'_>| {
                    let entity_resolvers = resolver.clone();
                    let config = config.clone();
                    FieldFuture::new(async move {
                        let representations = ctx
                            .args
                            .get("representations")
                            .ok_or_else(|| {
                                async_graphql::Error::new("missing representations argument")
                            })?
                            .list()?;
                        // Resolve each representation sequentially, in input
                        // order, dispatching on its `__typename`.
                        let mut results = Vec::new();
                        for repr in representations.iter() {
                            let obj = repr.object()?;
                            let mut representation_map = IndexMap::new();
                            for (key, value) in obj.iter() {
                                representation_map.insert(key.clone(), value.as_value().clone());
                            }
                            let typename = representation_map
                                .get(&Name::new("__typename"))
                                .and_then(|v| match v {
                                    GqlValue::String(s) => Some(s.as_str()),
                                    _ => None,
                                })
                                .ok_or_else(|| {
                                    async_graphql::Error::new(
                                        "missing __typename in representation",
                                    )
                                })?;
                            let entity_config = config.entities.get(typename).ok_or_else(|| {
                                async_graphql::Error::new(format!(
                                    "unknown entity type: {}",
                                    typename
                                ))
                            })?;
                            let entity = entity_resolvers
                                .resolve_entity(entity_config, &representation_map)
                                .await
                                .map_err(|e| async_graphql::Error::new(e.to_string()))?;
                            results.push(
                                FieldValue::value(entity)
                                    .with_type(entity_config.type_name.clone()),
                            );
                        }
                        Ok(Some(FieldValue::list(results)))
                    })
                });
            }
        }
        schema_builder = schema_builder.register(query_root);
        if let Some(mutation) = mutation_root {
            schema_builder = schema_builder.register(mutation);
        }
        if let Some(subscription) = subscription_root {
            schema_builder = schema_builder.register(subscription);
        }
        // Register every type generated from protobuf descriptors.
        for (_, en) in registry.enums {
            schema_builder = schema_builder.register(en);
        }
        for (_, input) in registry.input_objects {
            schema_builder = schema_builder.register(input);
        }
        for (type_name, obj) in registry.objects {
            // Federation directives are attached only in federation mode.
            let obj = if self.federation {
                federation_config.apply_directives_to_object(obj, &type_name)?
            } else {
                obj
            };
            schema_builder = schema_builder.register(obj);
        }
        if let Some(max_depth) = self.query_depth_limit {
            schema_builder = schema_builder.limit_depth(max_depth);
        }
        if let Some(max_complexity) = self.query_complexity_limit {
            schema_builder = schema_builder.limit_complexity(max_complexity);
        }
        if self.introspection_disabled {
            schema_builder = schema_builder.disable_introspection();
        }
        let schema = schema_builder
            .finish()
            .map_err(|e| Error::Schema(format!("failed to build schema: {e}")))?;
        Ok(DynamicSchema {
            inner: schema,
            live_query_configs: Arc::new(live_query_configs),
        })
    }
    /// Decodes the first descriptor set into a fresh pool and merges every
    /// subsequent set into it, in the order they were added.
    ///
    /// # Errors
    /// Returns `Error::Schema` when no descriptor sets exist or any set fails
    /// to decode. Merge errors report the failing set's 1-based position.
    fn build_merged_descriptor_pool(&self) -> Result<DescriptorPool> {
        let mut iter = self.descriptor_sets.iter();
        let first = iter
            .next()
            .ok_or_else(|| Error::Schema("at least one descriptor set is required".into()))?;
        let mut pool = DescriptorPool::decode(first.as_slice())
            .map_err(|e| Error::Schema(format!("failed to decode primary descriptor set: {e}")))?;
        // `index` is 0 for the second set, hence the `+ 2` for 1-based numbering.
        for (index, bytes) in iter.enumerate() {
            pool.decode_file_descriptor_set(bytes.as_slice())
                .map_err(|e| {
                    Error::Schema(format!(
                        "failed to merge descriptor set #{}: {e}",
                        index + 2
                    ))
                })?;
            tracing::debug!(
                "Merged descriptor set #{} ({} bytes) into schema pool",
                index + 2,
                bytes.len()
            );
        }
        if self.descriptor_sets.len() > 1 {
            tracing::info!(
                "Merged {} descriptor sets into unified schema ({} services, {} types)",
                self.descriptor_sets.len(),
                pool.services().count(),
                pool.all_messages().count()
            );
        }
        Ok(pool)
    }
    /// Returns how many descriptor sets have been added so far.
    pub fn descriptor_count(&self) -> usize {
        self.descriptor_sets.len()
    }
}
impl Default for SchemaBuilder {
fn default() -> Self {
Self::new()
}
}
fn build_rest_field(rest_field: &crate::rest_connector::RestGraphQLField) -> Field {
let field_name = rest_field.name.clone();
let description = rest_field.description.clone();
let connector = rest_field.connector.clone();
let endpoint_name = rest_field.endpoint_name.clone();
let parameters = rest_field.parameters.clone();
let return_type = if let Some(schema) = &rest_field.response_schema {
schema.type_name.clone()
} else {
"JSON".to_string()
};
let has_schema = rest_field.response_schema.is_some();
let mut field = Field::new(
field_name.clone(),
TypeRef::named(return_type),
move |ctx: ResolverContext<'_>| {
let connector = connector.clone();
let endpoint_name = endpoint_name.clone();
let parameters = parameters.clone();
FieldFuture::new(async move {
let mut args = std::collections::HashMap::new();
for param in ¶meters {
if let Ok(accessor) = ctx.args.try_get(param.as_str()) {
let json_value = gql_value_to_json(accessor.as_value().clone());
args.insert(param.clone(), json_value);
}
}
match connector.execute(&endpoint_name, args).await {
Ok(result) => {
if has_schema {
Ok(Some(FieldValue::owned_any(result)))
} else {
let gql_value = json_to_gql_value(result);
Ok(Some(FieldValue::value(gql_value)))
}
}
Err(e) => Err(async_graphql::Error::new(format!(
"REST endpoint {} failed: {}",
endpoint_name, e
))),
}
})
},
);
field = field.description(description);
for param in &rest_field.parameters {
field = field.argument(InputValue::new(param.clone(), TypeRef::named("String")));
}
field
}
/// Converts an async-graphql value into a `serde_json::Value`.
///
/// Conversion rules:
/// - Integers stay exact: `i64` is tried first, then `u64`, so values above
///   `i64::MAX` are not routed through `f64`.
/// - Other numbers become JSON floats, or `null` if not representable.
/// - Enums become their name as a string.
/// - Binary data becomes a standard base64 string.
/// - Lists and objects are converted recursively.
fn gql_value_to_json(value: GqlValue) -> serde_json::Value {
    match value {
        GqlValue::Null => serde_json::Value::Null,
        GqlValue::Number(n) => {
            if let Some(i) = n.as_i64() {
                serde_json::Value::Number(i.into())
            } else if let Some(u) = n.as_u64() {
                // Large unsigned integers would lose precision via as_f64.
                serde_json::Value::Number(u.into())
            } else if let Some(f) = n.as_f64() {
                serde_json::Number::from_f64(f)
                    .map(serde_json::Value::Number)
                    .unwrap_or(serde_json::Value::Null)
            } else {
                serde_json::Value::Null
            }
        }
        GqlValue::String(s) => serde_json::Value::String(s),
        GqlValue::Boolean(b) => serde_json::Value::Bool(b),
        GqlValue::Enum(e) => serde_json::Value::String(e.to_string()),
        GqlValue::List(arr) => {
            serde_json::Value::Array(arr.into_iter().map(gql_value_to_json).collect())
        }
        GqlValue::Object(obj) => {
            let map: serde_json::Map<String, serde_json::Value> = obj
                .into_iter()
                .map(|(k, v)| (k.to_string(), gql_value_to_json(v)))
                .collect();
            serde_json::Value::Object(map)
        }
        GqlValue::Binary(b) => serde_json::Value::String(BASE64.encode(&b)),
    }
}
/// Converts a `serde_json::Value` into an async-graphql value.
///
/// Integers stay exact: `i64` is tried first, then `u64`, so values above
/// `i64::MAX` are not routed through `f64`. Other numbers become floats. Arrays
/// and objects are converted recursively, with object keys kept in JSON order.
fn json_to_gql_value(value: serde_json::Value) -> GqlValue {
    match value {
        serde_json::Value::Null => GqlValue::Null,
        serde_json::Value::Bool(b) => GqlValue::Boolean(b),
        serde_json::Value::Number(n) => {
            if let Some(i) = n.as_i64() {
                GqlValue::Number(i.into())
            } else if let Some(u) = n.as_u64() {
                // Large unsigned integers would lose precision via as_f64.
                GqlValue::Number(u.into())
            } else if let Some(f) = n.as_f64() {
                match async_graphql::Number::from_f64(f) {
                    Some(num) => GqlValue::Number(num),
                    // Only non-finite floats are rejected; kept as 0 for
                    // consistency with json_to_gql_typed.
                    None => GqlValue::Number(0i64.into()),
                }
            } else {
                GqlValue::Null
            }
        }
        serde_json::Value::String(s) => GqlValue::String(s),
        serde_json::Value::Array(arr) => {
            GqlValue::List(arr.into_iter().map(json_to_gql_value).collect())
        }
        serde_json::Value::Object(obj) => {
            let map: async_graphql::indexmap::IndexMap<Name, GqlValue> = obj
                .into_iter()
                .map(|(k, v)| (Name::new(k), json_to_gql_value(v)))
                .collect();
            GqlValue::Object(map)
        }
    }
}
/// Builds the GraphQL object type for a REST response schema.
///
/// Each field resolver expects its parent value to be the raw
/// `serde_json::Value` handed down through `FieldValue::owned_any`. It reads
/// its own key; missing keys and non-object parents count as JSON `null`.
/// - `Object` / `List` fields pass the JSON through to nested resolvers, or
///   resolve to no value when it is `null`.
/// - All other fields are converted with `json_to_gql_typed`.
fn build_rest_response_type(schema: &RestResponseSchema) -> Object {
    let mut object = Object::new(&schema.type_name);
    if let Some(desc) = &schema.description {
        object = object.description(desc);
    }
    for field_def in &schema.fields {
        let field_name = field_def.name.clone();
        let field_type = field_def.field_type.clone();
        let type_ref = if field_def.nullable {
            TypeRef::named(field_type.to_type_ref())
        } else {
            TypeRef::named_nn(field_type.to_type_ref())
        };
        let mut field = Field::new(field_name.clone(), type_ref, move |ctx| {
            let field_name = field_name.clone();
            let field_type = field_type.clone();
            FieldFuture::new(async move {
                let parent_value = ctx.parent_value.try_downcast_ref::<serde_json::Value>()?;
                let value = match parent_value {
                    serde_json::Value::Object(map) => map
                        .get(&field_name)
                        .cloned()
                        .unwrap_or(serde_json::Value::Null),
                    _ => serde_json::Value::Null,
                };
                match field_type {
                    RestFieldType::Object(_) | RestFieldType::List(_) => {
                        // Wrapping JSON null with owned_any would make nested
                        // resolvers run against it and yield an object whose
                        // fields are all null. Return no value instead, so the
                        // field itself is null (presumably a null-violation
                        // error for non-null fields).
                        if value.is_null() {
                            Ok(None)
                        } else {
                            Ok(Some(FieldValue::owned_any(value)))
                        }
                    }
                    _ => {
                        let gql_value = json_to_gql_typed(value, &field_type);
                        Ok(Some(FieldValue::value(gql_value)))
                    }
                }
            })
        });
        if let Some(desc) = &field_def.description {
            field = field.description(desc);
        }
        object = object.field(field);
    }
    object
}
fn json_to_gql_typed(value: serde_json::Value, field_type: &RestFieldType) -> GqlValue {
match (value, field_type) {
(serde_json::Value::Null, _) => GqlValue::Null,
(serde_json::Value::String(s), RestFieldType::String) => GqlValue::String(s),
(v, RestFieldType::String) => GqlValue::String(v.to_string()),
(serde_json::Value::Number(n), RestFieldType::Int) => {
if let Some(i) = n.as_i64() {
GqlValue::Number(i.into())
} else {
GqlValue::Null
}
}
(serde_json::Value::Number(n), RestFieldType::Float) => {
if let Some(f) = n.as_f64() {
match async_graphql::Number::from_f64(f) {
Some(num) => GqlValue::Number(num),
None => GqlValue::Number(0i64.into()),
}
} else {
GqlValue::Null
}
}
(serde_json::Value::Bool(b), RestFieldType::Boolean) => GqlValue::Boolean(b),
(serde_json::Value::Array(arr), RestFieldType::List(inner)) => GqlValue::List(
arr.into_iter()
.map(|v| json_to_gql_typed(v, inner))
.collect(),
),
(v, _) => json_to_gql_value(v),
}
}
#[cfg(test)]
mod tests {
    use super::*;

    const FEDERATION_DESCRIPTOR: &[u8] =
        include_bytes!("generated/federation_example_descriptor.bin");

    const SERVICE_SDL_QUERY: &str = "{ _service { sdl } }";
    const INTROSPECTION_QUERY: &str = "{ __schema { types { name } } }";

    /// Builds a federation-enabled schema from the example descriptor, after
    /// letting the caller apply extra builder options.
    fn federation_schema(
        configure: impl FnOnce(SchemaBuilder) -> SchemaBuilder,
    ) -> DynamicSchema {
        configure(
            SchemaBuilder::new()
                .with_descriptor_set_bytes(FEDERATION_DESCRIPTOR)
                .enable_federation(),
        )
        .build(&GrpcClientPool::new())
        .expect("schema builds")
    }

    /// Executes `query` against `schema`.
    async fn run(schema: &DynamicSchema, query: &str) -> async_graphql::Response {
        schema.execute(async_graphql::Request::new(query)).await
    }

    #[tokio::test]
    async fn federation_adds_entities_query() {
        let schema = federation_schema(|builder| builder);
        let response = run(&schema, "{ _entities(representations: []) { __typename } }").await;
        assert!(
            response.errors.is_empty(),
            "expected _entities field to exist, got errors: {:?}",
            response.errors
        );
        let entities = response
            .data
            .into_json()
            .expect("valid JSON response")
            .get("_entities")
            .and_then(|value| value.as_array())
            .cloned()
            .unwrap_or_default();
        assert!(entities.is_empty(), "expected empty entities list");
    }

    #[tokio::test]
    async fn query_depth_limit_blocks_deep_queries() {
        let schema = federation_schema(|builder| builder.with_query_depth_limit(1));
        let response = run(&schema, SERVICE_SDL_QUERY).await;
        assert!(
            !response.errors.is_empty(),
            "expected depth limit error, but query succeeded"
        );
        let message = &response.errors[0].message;
        let lowered = message.to_lowercase();
        assert!(
            lowered.contains("nested") || lowered.contains("depth"),
            "expected depth error, got: {}",
            message
        );
    }

    #[tokio::test]
    async fn query_complexity_limit_blocks_complex_queries() {
        let schema = federation_schema(|builder| builder.with_query_complexity_limit(1));
        let response = run(&schema, SERVICE_SDL_QUERY).await;
        assert!(
            !response.errors.is_empty(),
            "expected complexity limit error, but query succeeded"
        );
        let message = &response.errors[0].message;
        let lowered = message.to_lowercase();
        assert!(
            lowered.contains("complex"),
            "expected complexity error, got: {}",
            message
        );
    }

    #[tokio::test]
    async fn schema_without_limits_allows_queries() {
        let schema = federation_schema(|builder| builder);
        let response = run(&schema, SERVICE_SDL_QUERY).await;
        assert!(
            response.errors.is_empty(),
            "expected success, got errors: {:?}",
            response.errors
        );
    }

    #[tokio::test]
    async fn reasonable_limits_allow_normal_queries() {
        let schema = federation_schema(|builder| {
            builder
                .with_query_depth_limit(10)
                .with_query_complexity_limit(100)
        });
        let response = run(&schema, SERVICE_SDL_QUERY).await;
        assert!(
            response.errors.is_empty(),
            "expected success with reasonable limits, got errors: {:?}",
            response.errors
        );
    }

    #[tokio::test]
    async fn introspection_disabled_blocks_schema_query() {
        let schema = federation_schema(|builder| builder.disable_introspection());
        let response = run(&schema, INTROSPECTION_QUERY).await;
        assert!(
            !response.errors.is_empty(),
            "expected introspection to be blocked, but query succeeded"
        );
        let message = &response.errors[0].message;
        let lowered = message.to_lowercase();
        let mentions_introspection = ["introspection", "disabled", "unknown", "__schema"]
            .iter()
            .any(|needle| lowered.contains(needle));
        assert!(
            mentions_introspection,
            "expected introspection error, got: {}",
            message
        );
    }

    #[tokio::test]
    async fn introspection_enabled_allows_schema_query() {
        let schema = federation_schema(|builder| builder);
        let response = run(&schema, INTROSPECTION_QUERY).await;
        assert!(
            response.errors.is_empty(),
            "expected introspection to work, got errors: {:?}",
            response.errors
        );
    }
}
/// Collects the GraphQL types generated from protobuf descriptors, keyed by
/// GraphQL type name, until `SchemaBuilder::build` registers them.
#[derive(Default)]
struct TypeRegistry {
    /// Output object types generated from protobuf messages.
    objects: HashMap<String, Object>,
    /// Input object types generated from messages used in arguments.
    input_objects: HashMap<String, InputObject>,
    /// Enum types generated from protobuf enums.
    enums: HashMap<String, Enum>,
}
/// Cache of resolved gRPC responses, keyed by [`GrpcCacheKey`].
///
/// Clones share the same underlying map (`Arc<Mutex<..>>`). Field resolvers
/// look this type up in the GraphQL context data and reuse a cached value for
/// an identical request.
#[derive(Clone, Default)]
pub struct GrpcResponseCache {
    inner: Arc<Mutex<HashMap<GrpcCacheKey, GqlValue>>>,
}

impl GrpcResponseCache {
    /// Returns a clone of the value cached for `key`, if present.
    pub fn get(&self, key: &GrpcCacheKey) -> Option<GqlValue> {
        self.lock().get(key).cloned()
    }

    /// Stores `value` under `key`, replacing any previous entry.
    pub fn insert(&self, key: GrpcCacheKey, value: GqlValue) {
        self.lock().insert(key, value);
    }

    /// Locks the map, recovering the guard if the mutex was poisoned.
    ///
    /// The only work done under this lock is a single `get` or `insert`, which
    /// cannot leave the map half-updated. Previously a poisoned lock silently
    /// turned every later `get` into a miss and every `insert` into a no-op for
    /// the cache's lifetime.
    fn lock(&self) -> std::sync::MutexGuard<'_, HashMap<GrpcCacheKey, GqlValue>> {
        self.inner
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}
/// Identity of a unary gRPC call for response caching. It combines three parts:
/// - the service name;
/// - the method path (e.g. `/pkg.Service/Method`);
/// - the protobuf-encoded request bytes.
#[derive(Hash, Eq, PartialEq, Clone)]
pub struct GrpcCacheKey {
    service: String,
    path: String,
    request: Vec<u8>,
}

impl GrpcCacheKey {
    /// Builds a key by encoding `request` to its protobuf wire bytes.
    ///
    /// # Errors
    /// Encoding via `encode_to_vec` cannot fail, so this always returns `Ok`.
    /// The `Result` is part of the existing signature.
    pub fn new(
        service: &str,
        path: &str,
        request: &DynamicMessage,
    ) -> std::result::Result<Self, prost::EncodeError> {
        let encoded = request.encode_to_vec();
        let key = Self {
            service: service.to_owned(),
            path: path.to_owned(),
            request: encoded,
        };
        Ok(key)
    }
}
impl TypeRegistry {
fn type_name_for_message(desc: &MessageDescriptor) -> String {
desc.full_name().replace('.', "_")
}
fn type_name_for_enum(desc: &EnumDescriptor) -> String {
desc.full_name().replace('.', "_")
}
fn ensure_enum(&mut self, desc: &EnumDescriptor) -> TypeRef {
let name = Self::type_name_for_enum(desc);
if !self.enums.contains_key(&name) {
let mut en = Enum::new(name.clone());
for value in desc.values() {
en = en.item(EnumItem::new(value.name()));
}
self.enums.insert(name.clone(), en);
}
TypeRef::named(name)
}
fn ensure_input_object(
&mut self,
message: &MessageDescriptor,
field_ext: &ExtensionDescriptor,
) -> TypeRef {
let name = Self::type_name_for_message(message);
if self.input_objects.contains_key(&name) {
return TypeRef::named(name);
}
let mut input = InputObject::new(name.clone());
for field in message.fields() {
if field_is_omitted(&field, field_ext) {
continue;
}
let field_name = graphql_field_name(&field, field_ext);
let required = field_is_required(&field, field_ext);
let ty = self.input_type_for_field(&field, field_ext, required);
let mut input_val = InputValue::new(field_name, ty);
if let Some(desc) = field_description(&field, field_ext) {
input_val = input_val.description(desc);
}
input = input.field(input_val);
}
self.input_objects.insert(name.clone(), input);
TypeRef::named(name)
}
fn ensure_object(
&mut self,
message: &MessageDescriptor,
field_ext: &ExtensionDescriptor,
) -> TypeRef {
let field_ext = field_ext.clone();
let name = Self::type_name_for_message(message);
if self.objects.contains_key(&name) {
return TypeRef::named(name);
}
let mut obj = Object::new(name.clone());
for field in message.fields() {
if field_is_omitted(&field, &field_ext) {
continue;
}
let field_name = graphql_field_name(&field, &field_ext);
let required = field_is_required(&field, &field_ext);
let ty = self.output_type_for_field(&field, &field_ext, required);
let field_desc = field.clone();
let field_name_for_value = field_name.clone();
let field_ext_for_resolver = field_ext.clone();
let mut gql_field = Field::new(field_name.clone(), ty, move |ctx| {
let field_ext = field_ext_for_resolver.clone();
let field_desc = field_desc.clone();
let field_name_for_value = field_name_for_value.clone();
FieldFuture::new(async move {
if let Some(parent) = ctx.parent_value.downcast_ref::<DynamicMessage>() {
let value = parent.get_field(&field_desc);
let gql_value =
prost_value_to_graphql(&value, Some(&field_desc), &field_ext)
.map_err(|e| async_graphql::Error::new(e.to_string()))?;
return Ok(Some(gql_value));
}
if let Some(GqlValue::Object(map)) = ctx.parent_value.as_value() {
if let Some(val) = map.get(&Name::new(field_name_for_value.clone())) {
return Ok(Some(val.clone()));
}
}
Ok(Some(GqlValue::Null))
})
});
if field_is_external(&field, &field_ext) {
gql_field = gql_field.directive(async_graphql::dynamic::Directive::new("external"));
}
if let Some(requires) = field_requires(&field, &field_ext) {
gql_field = gql_field.directive(
async_graphql::dynamic::Directive::new("requires")
.argument("fields", GqlValue::String(requires)),
);
}
if let Some(provides) = field_provides(&field, &field_ext) {
gql_field = gql_field.directive(
async_graphql::dynamic::Directive::new("provides")
.argument("fields", GqlValue::String(provides)),
);
}
if field_is_shareable(&field, &field_ext) {
gql_field =
gql_field.directive(async_graphql::dynamic::Directive::new("shareable"));
}
if let Some(desc) = field_description(&field, &field_ext) {
gql_field = gql_field.description(desc);
}
obj = obj.field(gql_field);
}
self.objects.insert(name.clone(), obj);
TypeRef::named(name)
}
fn input_type_for_field(
&mut self,
field: &FieldDescriptor,
field_ext: &ExtensionDescriptor,
required: bool,
) -> TypeRef {
type_for_field(self, field, field_ext, required, true)
}
fn output_type_for_field(
&mut self,
field: &FieldDescriptor,
field_ext: &ExtensionDescriptor,
required: bool,
) -> TypeRef {
type_for_field(self, field, field_ext, required, false)
}
}
/// Maps a protobuf field to its GraphQL type reference.
///
/// Kind mapping:
/// - 64-bit integers are exposed as `String`; 32-bit integers as `Int`.
/// - `bytes` is `Upload` on input and `String` on output.
/// - Enums and messages are registered in `registry` on demand.
///
/// Wrapping: `required` wraps the element in non-null, and repeated fields are
/// then wrapped in a (nullable) list.
fn type_for_field(
    registry: &mut TypeRegistry,
    field: &FieldDescriptor,
    field_ext: &ExtensionDescriptor,
    required: bool,
    is_input: bool,
) -> TypeRef {
    let mut ty = match field.kind() {
        Kind::Bool => TypeRef::named(TypeRef::BOOLEAN),
        Kind::String => TypeRef::named(TypeRef::STRING),
        Kind::Bytes if is_input => TypeRef::named(TypeRef::UPLOAD),
        Kind::Bytes => TypeRef::named(TypeRef::STRING),
        Kind::Float | Kind::Double => TypeRef::named(TypeRef::FLOAT),
        Kind::Int32 | Kind::Sint32 | Kind::Sfixed32 | Kind::Uint32 | Kind::Fixed32 => {
            TypeRef::named(TypeRef::INT)
        }
        Kind::Int64 | Kind::Sint64 | Kind::Sfixed64 | Kind::Uint64 | Kind::Fixed64 => {
            TypeRef::named(TypeRef::STRING)
        }
        Kind::Enum(en) => registry.ensure_enum(&en),
        Kind::Message(msg) if is_input => registry.ensure_input_object(&msg, field_ext),
        Kind::Message(msg) => registry.ensure_object(&msg, field_ext),
    };
    if required {
        ty = TypeRef::NonNull(Box::new(ty));
    }
    if field.is_list() {
        ty = TypeRef::List(Box::new(ty));
    }
    ty
}
/// Minimal `Query` root with a single non-null `_placeholder: Boolean!` field
/// that always resolves to `true`. It is used when no real query field exists.
fn placeholder_query_root() -> Object {
    let always_true = Field::new(
        "_placeholder",
        TypeRef::named_nn(TypeRef::BOOLEAN),
        |_| FieldFuture::new(async { Ok(Some(GqlValue::Boolean(true))) }),
    );
    Object::new("Query").field(always_true)
}
/// A GraphQL argument (name and type) produced by `build_arguments` from a
/// method's input message.
#[derive(Clone)]
struct ArgumentSpec {
    name: String,
    ty: TypeRef,
}

impl ArgumentSpec {
    /// Turns the spec into an async-graphql `InputValue` for `Field::argument`.
    fn into_input_value(self) -> InputValue {
        let Self { name, ty } = self;
        InputValue::new(name, ty)
    }
}
/// Everything a generated field resolver needs to issue its gRPC call,
/// precomputed once per method at schema build time.
#[derive(Clone)]
struct OperationConfig {
    /// Fully-qualified service name (the client pool key).
    service_name: String,
    /// gRPC method path, `/{service}/{method}`.
    grpc_path: String,
    /// Request message descriptor.
    input_desc: MessageDescriptor,
    /// Response message descriptor (used by the decoding codec).
    output_desc: MessageDescriptor,
    /// Name of the single wrapper argument, when the `request` option sets one.
    request_wrapper_name: Option<String>,
    /// GraphQL return type of the field.
    return_type: TypeRef,
    /// GraphQL arguments derived from the request message.
    args: Vec<ArgumentSpec>,
    /// The method's `graphql.schema` options.
    schema_opts: GraphqlSchema,
    /// The `graphql.field` extension descriptor.
    field_ext: ExtensionDescriptor,
}

impl OperationConfig {
    /// Derives the operation config for `method`, registering any GraphQL
    /// types it needs in `registry`. The return type is computed before the
    /// arguments.
    ///
    /// # Errors
    /// Currently always returns `Ok`.
    fn new(
        service: &prost_reflect::ServiceDescriptor,
        method: &prost_reflect::MethodDescriptor,
        schema_opts: &GraphqlSchema,
        field_ext: ExtensionDescriptor,
        registry: &mut TypeRegistry,
    ) -> Result<Self> {
        let input_desc = method.input();
        let output_desc = method.output();
        let response_opts = schema_opts.response.as_ref();
        let return_required = response_opts.map_or(false, |resp| resp.required);
        let return_type = compute_return_type(
            &output_desc,
            response_opts,
            &field_ext,
            registry,
            return_required,
        );
        let request_wrapper_name = schema_opts.request.as_ref().and_then(|req| {
            if req.name.is_empty() {
                None
            } else {
                Some(req.name.clone())
            }
        });
        let args = build_arguments(&input_desc, &request_wrapper_name, &field_ext, registry);
        let service_name = service.full_name().to_string();
        let grpc_path = format!("/{}/{}", service_name, method.name());
        Ok(Self {
            service_name,
            grpc_path,
            input_desc,
            output_desc,
            request_wrapper_name,
            return_type,
            args,
            schema_opts: schema_opts.clone(),
            field_ext,
        })
    }
}
/// Builds a GraphQL query/mutation field whose resolver issues a unary gRPC call.
///
/// Argument and return types are registered in `registry` via `OperationConfig::new`.
/// At resolve time the resolver:
/// 1. looks up the gRPC client for the service in `client_pool`;
/// 2. converts the GraphQL arguments into the request message;
/// 3. returns the cached value when a `GrpcResponseCache` is present in the context data
///    and holds an entry for the key built from service name, method path and request;
/// 4. otherwise forwards configured headers, and runs `PluginRegistry::on_subgraph_request`
///    hooks, which receive the mutable request metadata.
/// 5. It then performs the call and applies the response pluck. The result is stored in the
///    cache when a cache is present.
///
/// # Errors
/// Propagates errors from `OperationConfig::new`. Per-request failures are reported as
/// GraphQL errors by the resolver.
#[allow(clippy::too_many_arguments)]
fn build_field(
    name: String,
    service: &prost_reflect::ServiceDescriptor,
    method: &prost_reflect::MethodDescriptor,
    schema_opts: &GraphqlSchema,
    field_ext: ExtensionDescriptor,
    registry: &mut TypeRegistry,
    client_pool: GrpcClientPool,
    header_propagation: Option<HeaderPropagationConfig>,
) -> Result<Field> {
    let config = OperationConfig::new(service, method, schema_opts, field_ext, registry)?;
    // `config` is moved into the resolver closure, so keep a copy of the argument specs
    // to attach to the field after it is created.
    let args = config.args.clone();
    let field = Field::new(name, config.return_type.clone(), move |ctx| {
        // The closure can be invoked repeatedly; each invocation clones what it moves
        // into its own future.
        let client_pool = client_pool.clone();
        let config = config.clone();
        let header_propagation = header_propagation.clone();
        FieldFuture::new(async move {
            let client = client_pool.get(&config.service_name).ok_or_else(|| {
                async_graphql::Error::new(format!("gRPC client {} not found", config.service_name))
            })?;
            let request_msg = build_request_message(
                &config.input_desc,
                &ctx,
                &config.request_wrapper_name,
                &config.field_ext,
            )?;
            // The key is computed before checking whether a cache is installed at all.
            let cache_key =
                GrpcCacheKey::new(&config.service_name, &config.grpc_path, &request_msg).map_err(
                    |e| async_graphql::Error::new(format!("failed to encode request: {e}")),
                )?;
            // Serve from the context's response cache (if any) before touching the network.
            if let Some(cache) = ctx.data_opt::<GrpcResponseCache>() {
                if let Some(val) = cache.get(&cache_key) {
                    return Ok(Some(val));
                }
            }
            let mut grpc = Grpc::new(client.channel());
            grpc.ready()
                .await
                .map_err(|e| async_graphql::Error::new(format!("gRPC not ready: {e}")))?;
            let codec = ReflectCodec::new(config.output_desc.clone());
            let path: http::uri::PathAndQuery = config
                .grpc_path
                .parse()
                .map_err(|e| async_graphql::Error::new(format!("invalid gRPC path: {e}")))?;
            let mut grpc_request = tonic::Request::new(request_msg);
            // Copy the configured incoming HTTP headers into outgoing gRPC metadata.
            if let Some(ref propagation_config) = header_propagation {
                if let Some(middleware_ctx) = ctx.data_opt::<crate::middleware::Context>() {
                    let metadata = propagation_config.extract_metadata(&middleware_ctx.headers);
                    grpc_request =
                        crate::headers::apply_metadata_to_request(grpc_request, metadata);
                }
            }
            // Plugins get mutable metadata; an error from a plugin aborts the request.
            if let Some(plugins) = ctx.data_opt::<crate::plugin::PluginRegistry>() {
                plugins
                    .on_subgraph_request(&config.service_name, grpc_request.metadata_mut())
                    .await
                    .map_err(|e| async_graphql::Error::new(e.to_string()))?;
            }
            let response = grpc
                .unary(grpc_request, path, codec)
                .await
                .map_err(|e| async_graphql::Error::new(format!("gRPC error: {e}")))?
                .into_inner();
            let value = apply_response_pluck(&response, &config.schema_opts, &config.field_ext)?;
            if let Some(cache) = ctx.data_opt::<GrpcResponseCache>() {
                cache.insert(cache_key, value.clone());
            }
            Ok(Some(value))
        })
    });
    let mut field = field;
    for arg in args {
        field = field.argument(arg.into_input_value());
    }
    Ok(field)
}
#[allow(clippy::too_many_arguments)]
fn build_subscription_field(
name: String,
service: &prost_reflect::ServiceDescriptor,
method: &prost_reflect::MethodDescriptor,
schema_opts: &GraphqlSchema,
field_ext: ExtensionDescriptor,
registry: &mut TypeRegistry,
client_pool: GrpcClientPool,
header_propagation: Option<HeaderPropagationConfig>,
) -> Result<SubscriptionField> {
let config = OperationConfig::new(service, method, schema_opts, field_ext, registry)?;
let args = config.args.clone();
let field = SubscriptionField::new(name, config.return_type.clone(), move |ctx| {
let client_pool = client_pool.clone();
let config = config.clone();
let header_propagation = header_propagation.clone();
SubscriptionFieldFuture::new(async move {
let client = client_pool.get(&config.service_name).ok_or_else(|| {
async_graphql::Error::new(format!("gRPC client {} not found", config.service_name))
})?;
let request_msg = build_request_message(
&config.input_desc,
&ctx,
&config.request_wrapper_name,
&config.field_ext,
)?;
let mut grpc = Grpc::new(client.channel());
grpc.ready()
.await
.map_err(|e| async_graphql::Error::new(format!("gRPC not ready: {e}")))?;
let codec = ReflectCodec::new(config.output_desc.clone());
let path: http::uri::PathAndQuery = config
.grpc_path
.parse()
.map_err(|e| async_graphql::Error::new(format!("invalid gRPC path: {e}")))?;
let mut grpc_request = tonic::Request::new(request_msg);
if let Some(ref propagation_config) = header_propagation {
if let Some(middleware_ctx) = ctx.data_opt::<crate::middleware::Context>() {
let metadata = propagation_config.extract_metadata(&middleware_ctx.headers);
grpc_request =
crate::headers::apply_metadata_to_request(grpc_request, metadata);
}
}
let response = grpc
.server_streaming(grpc_request, path, codec)
.await
.map_err(|e| async_graphql::Error::new(format!("gRPC error: {e}")))?;
let schema_opts = config.schema_opts.clone();
let field_ext = config.field_ext.clone();
let stream = response.into_inner().map(move |item| {
let msg =
item.map_err(|e| async_graphql::Error::new(format!("gRPC error: {e}")))?;
apply_response_pluck(&msg, &schema_opts, &field_ext)
.map_err(|e| async_graphql::Error::new(e.to_string()))
});
Ok(stream)
})
});
let mut field = field;
for arg in args {
field = field.argument(arg.into_input_value());
}
Ok(field)
}
/// Computes the GraphQL argument list for an operation whose request type is `input_desc`.
///
/// With a wrapper name, the whole request message is exposed as one input-object argument
/// of that name. Without one, every non-omitted request field becomes its own argument,
/// named by `graphql_field_name`. Each such argument is non-null when the field is marked
/// `required`. Input types are registered in `registry` in field order.
fn build_arguments(
    input_desc: &MessageDescriptor,
    wrapper: &Option<String>,
    field_ext: &ExtensionDescriptor,
    registry: &mut TypeRegistry,
) -> Vec<ArgumentSpec> {
    match wrapper {
        Some(wrapper_name) => vec![ArgumentSpec {
            name: wrapper_name.clone(),
            ty: registry.ensure_input_object(input_desc, field_ext),
        }],
        None => input_desc
            .fields()
            .filter(|field| !field_is_omitted(field, field_ext))
            .map(|field| {
                let is_required = field_is_required(&field, field_ext);
                let ty = registry.input_type_for_field(&field, field_ext, is_required);
                ArgumentSpec {
                    name: graphql_field_name(&field, field_ext),
                    ty,
                }
            })
            .collect(),
    }
}
/// Reports whether the field's `GraphqlField` option sets `omit`.
///
/// A missing or undecodable option counts as "not omitted".
fn field_is_omitted(field: &FieldDescriptor, field_ext: &ExtensionDescriptor) -> bool {
    match decode_extension::<GraphqlField>(&field.options(), field_ext) {
        Ok(Some(opts)) => opts.omit,
        _ => false,
    }
}
/// Reports whether the field's `GraphqlField` option sets `required`.
///
/// A missing or undecodable option counts as "not required".
fn field_is_required(field: &FieldDescriptor, field_ext: &ExtensionDescriptor) -> bool {
    match decode_extension::<GraphqlField>(&field.options(), field_ext) {
        Ok(Some(opts)) => opts.required,
        _ => false,
    }
}
/// Returns the GraphQL-facing name of `field`.
///
/// Uses the non-empty `name` from the field's `GraphqlField` option when present. Otherwise
/// it falls back to the protobuf field name.
fn graphql_field_name(field: &FieldDescriptor, field_ext: &ExtensionDescriptor) -> String {
    let custom_name = match decode_extension::<GraphqlField>(&field.options(), field_ext) {
        Ok(Some(opts)) if !opts.name.is_empty() => Some(opts.name),
        _ => None,
    };
    custom_name.unwrap_or_else(|| field.name().to_string())
}
/// Reports whether the field's `GraphqlField` option sets `external`.
///
/// A missing or undecodable option counts as "not external".
fn field_is_external(field: &FieldDescriptor, field_ext: &ExtensionDescriptor) -> bool {
    match decode_extension::<GraphqlField>(&field.options(), field_ext) {
        Ok(Some(opts)) => opts.external,
        _ => false,
    }
}
/// Returns the field's `GraphqlField.requires` selection, or `None` when absent or empty.
fn field_requires(field: &FieldDescriptor, field_ext: &ExtensionDescriptor) -> Option<String> {
    decode_extension::<GraphqlField>(&field.options(), field_ext)
        .ok()
        .flatten()
        .map(|opts| opts.requires)
        .filter(|requires| !requires.is_empty())
}
/// Returns the field's `GraphqlField.provides` selection, or `None` when absent or empty.
fn field_provides(field: &FieldDescriptor, field_ext: &ExtensionDescriptor) -> Option<String> {
    decode_extension::<GraphqlField>(&field.options(), field_ext)
        .ok()
        .flatten()
        .map(|opts| opts.provides)
        .filter(|provides| !provides.is_empty())
}
/// Reports whether the field's `GraphqlField` option sets `shareable`.
///
/// A missing or undecodable option counts as "not shareable".
fn field_is_shareable(field: &FieldDescriptor, field_ext: &ExtensionDescriptor) -> bool {
    match decode_extension::<GraphqlField>(&field.options(), field_ext) {
        Ok(Some(opts)) => opts.shareable,
        _ => false,
    }
}
/// Returns the field's `GraphqlField.description`, or `None` when absent or empty.
fn field_description(field: &FieldDescriptor, field_ext: &ExtensionDescriptor) -> Option<String> {
    decode_extension::<GraphqlField>(&field.options(), field_ext)
        .ok()
        .flatten()
        .map(|opts| opts.description)
        .filter(|description| !description.is_empty())
}
/// Determines the GraphQL return type for an operation whose response type is `output_desc`.
///
/// When `response_opts.pluck` names an existing response field, that field's output type is
/// used. Otherwise the whole response message is registered and used as an object type. An
/// empty or unknown pluck name falls back to the whole message. The result is wrapped in
/// `NonNull` when `required` is set.
fn compute_return_type(
    output_desc: &MessageDescriptor,
    response_opts: Option<&GraphqlResponse>,
    field_ext: &ExtensionDescriptor,
    registry: &mut TypeRegistry,
    required: bool,
) -> TypeRef {
    let plucked_field = response_opts
        .map(|resp| resp.pluck.as_str())
        .filter(|pluck| !pluck.is_empty())
        .and_then(|pluck| output_desc.get_field_by_name(pluck));
    let base = match plucked_field {
        Some(field_desc) => registry.output_type_for_field(&field_desc, field_ext, false),
        None => registry.ensure_object(output_desc, field_ext),
    };
    if !required {
        return base;
    }
    TypeRef::NonNull(Box::new(base))
}
fn build_request_message(
desc: &MessageDescriptor,
ctx: &ResolverContext<'_>,
wrapper: &Option<String>,
field_ext: &ExtensionDescriptor,
) -> async_graphql::Result<DynamicMessage> {
if let Some(wrapper_name) = wrapper {
let wrapper_value = ctx
.args
.get(wrapper_name)
.ok_or_else(|| async_graphql::Error::new(format!("missing argument {wrapper_name}")))?;
if let GqlValue::Object(obj) = wrapper_value.as_value() {
return object_to_message(desc, obj, ctx, field_ext);
} else {
return Err(async_graphql::Error::new(format!(
"argument {wrapper_name} must be an object"
)));
}
}
let mut message = DynamicMessage::new(desc.clone());
for field in desc.fields() {
if field_is_omitted(&field, field_ext) {
continue;
}
let arg_name = graphql_field_name(&field, field_ext);
if let Some(value) = ctx.args.get(&arg_name) {
let prost_value = graphql_input_to_prost(value.as_value(), &field, ctx, field_ext)?;
message.set_field(&field, prost_value);
}
}
Ok(message)
}
fn object_to_message(
desc: &MessageDescriptor,
values: &IndexMap<Name, GqlValue>,
ctx: &ResolverContext<'_>,
field_ext: &ExtensionDescriptor,
) -> async_graphql::Result<DynamicMessage> {
let mut message = DynamicMessage::new(desc.clone());
for field in desc.fields() {
if field_is_omitted(&field, field_ext) {
continue;
}
let arg_name = graphql_field_name(&field, field_ext);
if let Some(value) = values.get(&Name::new(arg_name)) {
let prost_value = graphql_input_to_prost(value, &field, ctx, field_ext)?;
message.set_field(&field, prost_value);
}
}
Ok(message)
}
fn graphql_input_to_prost(
value: &GqlValue,
field: &FieldDescriptor,
ctx: &ResolverContext<'_>,
field_ext: &ExtensionDescriptor,
) -> async_graphql::Result<Value> {
if field.is_list() {
let list = match value {
GqlValue::List(list) => list,
_ => return Err(async_graphql::Error::new("expected list")),
};
let mut items = Vec::new();
for v in list {
items.push(single_input_to_prost(v, field, ctx, field_ext)?);
}
return Ok(Value::List(items));
}
single_input_to_prost(value, field, ctx, field_ext)
}
fn single_input_to_prost(
value: &GqlValue,
field: &FieldDescriptor,
ctx: &ResolverContext<'_>,
field_ext: &ExtensionDescriptor,
) -> async_graphql::Result<Value> {
let kind = field.kind();
match kind {
Kind::Bool => match value {
GqlValue::Boolean(b) => Ok(Value::Bool(*b)),
_ => Err(async_graphql::Error::new("expected boolean")),
},
Kind::String => match value {
GqlValue::String(s) => Ok(Value::String(s.clone())),
_ => Err(async_graphql::Error::new("expected string")),
},
Kind::Bytes => match value {
GqlValue::String(s) => {
if let Some(bytes) = upload_marker_to_bytes(ctx, s)? {
return Ok(Value::Bytes(bytes.into()));
}
BASE64
.decode(s)
.map(|b| Value::Bytes(b.into()))
.map_err(|e| async_graphql::Error::new(format!("invalid base64: {e}")))
}
GqlValue::Binary(b) => Ok(Value::Bytes(b.to_vec().into())),
_ => Err(async_graphql::Error::new(
"expected upload or base64 string",
)),
},
Kind::Float | Kind::Double => match value {
GqlValue::Number(n) => n
.as_f64()
.map(Value::F64)
.ok_or_else(|| async_graphql::Error::new("expected float")),
_ => Err(async_graphql::Error::new("expected float")),
},
Kind::Int32 | Kind::Sint32 | Kind::Sfixed32 => match value {
GqlValue::Number(n) => n
.as_i64()
.map(|v| Value::I32(v as i32))
.ok_or_else(|| async_graphql::Error::new("expected int")),
_ => Err(async_graphql::Error::new("expected int")),
},
Kind::Uint32 | Kind::Fixed32 => match value {
GqlValue::Number(n) => n
.as_u64()
.map(|v| Value::U32(v as u32))
.ok_or_else(|| async_graphql::Error::new("expected unsigned int")),
_ => Err(async_graphql::Error::new("expected unsigned int")),
},
Kind::Int64 | Kind::Sint64 | Kind::Sfixed64 => match value {
GqlValue::String(s) => s
.parse::<i64>()
.map(Value::I64)
.map_err(|e| async_graphql::Error::new(format!("invalid i64: {e}"))),
GqlValue::Number(n) => n
.as_i64()
.map(Value::I64)
.ok_or_else(|| async_graphql::Error::new("expected 64-bit int")),
_ => Err(async_graphql::Error::new("expected 64-bit int (as string)")),
},
Kind::Uint64 | Kind::Fixed64 => match value {
GqlValue::String(s) => s
.parse::<u64>()
.map(Value::U64)
.map_err(|e| async_graphql::Error::new(format!("invalid u64: {e}"))),
GqlValue::Number(n) => n
.as_u64()
.map(Value::U64)
.ok_or_else(|| async_graphql::Error::new("expected 64-bit uint")),
_ => Err(async_graphql::Error::new(
"expected 64-bit uint (as string)",
)),
},
Kind::Enum(en) => {
let name_str = match value {
GqlValue::Enum(name) => Some(name.as_str()),
GqlValue::String(name) => Some(name.as_str()),
_ => None,
}
.ok_or_else(|| async_graphql::Error::new("expected enum value"))?;
let num = en
.get_value_by_name(name_str)
.map(|v| v.number())
.ok_or_else(|| async_graphql::Error::new("invalid enum value"))?;
Ok(Value::EnumNumber(num))
}
Kind::Message(msg) => match value {
GqlValue::Object(obj) => {
object_to_message(&msg, obj, ctx, field_ext).map(Value::Message)
}
_ => Err(async_graphql::Error::new("expected object")),
},
}
}
const UPLOAD_REF_PREFIX: &str = "#__graphql_file__:";
fn upload_marker_to_bytes(
ctx: &ResolverContext<'_>,
marker: &str,
) -> async_graphql::Result<Option<Vec<u8>>> {
let Some(rest) = marker.strip_prefix(UPLOAD_REF_PREFIX) else {
return Ok(None);
};
let idx = rest.parse::<usize>().map_err(|e| {
async_graphql::Error::new(format!("invalid upload reference {marker}: {e}"))
})?;
let upload: &UploadValue = ctx
.query_env
.uploads
.get(idx)
.ok_or_else(|| async_graphql::Error::new(format!("upload index {idx} not found")))?;
let reader = upload
.try_clone()
.map_err(|e| async_graphql::Error::new(format!("failed to clone upload: {e}")))?;
let mut bytes = Vec::new();
reader
.into_read()
.read_to_end(&mut bytes)
.map_err(|e| async_graphql::Error::new(format!("failed to read upload: {e}")))?;
Ok(Some(bytes))
}
/// Converts a protobuf [`Value`] into a GraphQL output value.
///
/// `field` is the descriptor the value was read from, when known; it is used to render enum
/// numbers by name. Enum numbers with no matching variant, or with no enum descriptor
/// available, are emitted as integers. Other conversions:
/// - bytes are base64-encoded;
/// - nested messages go through `dynamic_message_to_value`;
/// - lists are converted element-wise with the same `field`;
/// - maps become objects keyed by the stringified map key.
///
/// Map values are converted against the map entry's `value` field, so enum-valued maps
/// render variant names. Unsigned 64-bit values are emitted without wrapping to negative
/// numbers.
///
/// # Errors
/// Propagates errors from converting nested messages. These were previously replaced by
/// `null` silently.
fn prost_value_to_graphql(
    value: &Value,
    field: Option<&FieldDescriptor>,
    field_ext: &ExtensionDescriptor,
) -> Result<GqlValue> {
    let kind = field.map(|f| f.kind());
    Ok(match value {
        Value::Bool(b) => GqlValue::Boolean(*b),
        Value::I32(v) => GqlValue::from(*v),
        Value::I64(v) => GqlValue::from(*v),
        Value::U32(v) => GqlValue::from(*v as i64),
        // `as i64` would wrap values above i64::MAX into negative numbers.
        Value::U64(v) => GqlValue::from(*v),
        Value::F32(v) => GqlValue::from(*v),
        Value::F64(v) => GqlValue::from(*v),
        Value::String(s) => GqlValue::from(s.clone()),
        Value::Bytes(b) => GqlValue::from(BASE64.encode(b)),
        Value::EnumNumber(num) => match &kind {
            Some(Kind::Enum(en)) => match en.get_value(*num) {
                Some(val) => GqlValue::Enum(Name::new(val.name())),
                None => GqlValue::from(*num),
            },
            _ => GqlValue::from(*num),
        },
        Value::Message(msg) => dynamic_message_to_value(msg, field_ext)?,
        Value::List(list) => {
            let items = list
                .iter()
                .map(|v| prost_value_to_graphql(v, field, field_ext))
                .collect::<Result<Vec<_>>>()?;
            GqlValue::List(items)
        }
        Value::Map(map) => {
            // For a map field, `field.kind()` is the synthetic entry message; the values'
            // own descriptor is the entry's `value` field.
            let value_field = match &kind {
                Some(Kind::Message(entry)) if entry.is_map_entry() => {
                    Some(entry.map_entry_value_field())
                }
                _ => None,
            };
            let mut obj = IndexMap::with_capacity(map.len());
            for (k, v) in map {
                obj.insert(
                    Name::new(map_key_to_string(k)),
                    prost_value_to_graphql(v, value_field.as_ref(), field_ext)?,
                );
            }
            GqlValue::Object(obj)
        }
    })
}
/// Converts a protobuf message into a GraphQL object.
///
/// Omitted fields are skipped. Every other field, set or not, is emitted under its GraphQL
/// name in descriptor order, with its value converted by `prost_value_to_graphql`.
///
/// # Errors
/// Propagates conversion errors from any field value.
fn dynamic_message_to_value(
    message: &DynamicMessage,
    field_ext: &ExtensionDescriptor,
) -> Result<GqlValue> {
    let descriptor = message.descriptor();
    let mut object = IndexMap::new();
    for field_desc in descriptor
        .fields()
        .filter(|f| !field_is_omitted(f, field_ext))
    {
        let converted =
            prost_value_to_graphql(&message.get_field(&field_desc), Some(&field_desc), field_ext)?;
        object.insert(Name::new(graphql_field_name(&field_desc, field_ext)), converted);
    }
    Ok(GqlValue::Object(object))
}
/// Produces the GraphQL result for a gRPC response, honoring the `response.pluck` option.
///
/// If `pluck` names a field that exists on the response, only that field's value is
/// returned. Otherwise, including when the name is empty or unknown, the whole message is
/// converted to an object.
fn apply_response_pluck(
    response: &DynamicMessage,
    schema_opts: &GraphqlSchema,
    field_ext: &ExtensionDescriptor,
) -> Result<GqlValue> {
    let plucked_field = schema_opts
        .response
        .as_ref()
        .map(|resp| resp.pluck.as_str())
        .filter(|pluck| !pluck.is_empty())
        .and_then(|pluck| response.descriptor().get_field_by_name(pluck));
    match plucked_field {
        Some(field_desc) => {
            prost_value_to_graphql(&response.get_field(&field_desc), Some(&field_desc), field_ext)
        }
        None => dynamic_message_to_value(response, field_ext),
    }
}
/// Renders a protobuf map key as the string used for the GraphQL object key.
fn map_key_to_string(key: &MapKey) -> String {
    match key {
        MapKey::String(text) => text.to_owned(),
        MapKey::Bool(flag) => flag.to_string(),
        MapKey::I32(n) => n.to_string(),
        MapKey::I64(n) => n.to_string(),
        MapKey::U32(n) => n.to_string(),
        MapKey::U64(n) => n.to_string(),
    }
}
fn decode_extension<T: Message + Default>(
opts: &DynamicMessage,
ext: &ExtensionDescriptor,
) -> Result<Option<T>> {
if !opts.has_extension(ext) {
return Ok(None);
}
let val = opts.get_extension(ext);
if let Value::Message(msg) = val.as_ref() {
T::decode(msg.encode_to_vec().as_slice())
.map(Some)
.map_err(|e| Error::Schema(format!("failed to decode extension: {e}")))
} else {
Ok(None)
}
}
/// gRPC codec that sends and receives [`DynamicMessage`]s described at runtime.
///
/// Requests are encoded from whatever message they carry; responses are decoded against
/// `response_desc`.
struct ReflectCodec {
    response_desc: MessageDescriptor,
}
impl ReflectCodec {
    /// Creates a codec that decodes responses as `response_desc`.
    fn new(response_desc: MessageDescriptor) -> Self {
        Self { response_desc }
    }
}
impl Codec for ReflectCodec {
    type Encode = DynamicMessage;
    type Decode = DynamicMessage;
    type Encoder = ReflectEncoder;
    type Decoder = ReflectDecoder;
    fn encoder(&mut self) -> Self::Encoder {
        ReflectEncoder
    }
    fn decoder(&mut self) -> Self::Decoder {
        ReflectDecoder {
            desc: self.response_desc.clone(),
        }
    }
}
/// Decodes response frames into [`DynamicMessage`]s of type `desc`.
struct ReflectDecoder {
    desc: MessageDescriptor,
}
/// Encodes outgoing [`DynamicMessage`] requests in protobuf wire format.
struct ReflectEncoder;
impl Decoder for ReflectDecoder {
    type Item = DynamicMessage;
    type Error = Status;
    /// Decodes the single framed message contained in `buf`.
    ///
    /// A zero-length frame is a valid protobuf encoding: an empty message, or one whose
    /// fields all hold default values. It must therefore decode to a message rather than
    /// `None`. Previously such frames returned `Ok(None)`, and tonic then treated the
    /// response as missing.
    fn decode(
        &mut self,
        buf: &mut DecodeBuf<'_>,
    ) -> std::result::Result<Option<Self::Item>, Self::Error> {
        let len = Buf::remaining(buf);
        let bytes = buf.copy_to_bytes(len);
        DynamicMessage::decode(self.desc.clone(), bytes)
            .map(Some)
            .map_err(|e| Status::internal(format!("decode error: {e}")))
    }
}
impl Encoder for ReflectEncoder {
    type Item = DynamicMessage;
    type Error = Status;
    fn encode(
        &mut self,
        item: Self::Item,
        buf: &mut EncodeBuf<'_>,
    ) -> std::result::Result<(), Self::Error> {
        item.encode(buf)
            .map_err(|e| Status::internal(format!("encode error: {e}")))
    }
}
/// Unit tests for the `SchemaBuilder` configuration surface and its build-time validation.
#[cfg(test)]
mod builder_tests {
    use super::*;

    #[test]
    fn test_schema_builder_default() {
        assert_eq!(SchemaBuilder::default().descriptor_count(), 0);
    }

    #[test]
    fn test_schema_builder_with_descriptor_bytes() {
        let configured = SchemaBuilder::new().with_descriptor_set_bytes(vec![1, 2, 3]);
        assert_eq!(configured.descriptor_count(), 1);
    }

    #[test]
    fn test_schema_builder_add_descriptor_bytes() {
        let configured = SchemaBuilder::new()
            .with_descriptor_set_bytes(vec![1])
            .add_descriptor_set_bytes(vec![2]);
        assert_eq!(configured.descriptor_count(), 2);
    }

    #[test]
    fn test_schema_builder_enable_federation() {
        assert!(SchemaBuilder::new().enable_federation().federation);
    }

    #[test]
    fn test_schema_builder_disable_introspection() {
        assert!(SchemaBuilder::new().disable_introspection().introspection_disabled);
    }

    #[test]
    fn test_schema_builder_depth_limit() {
        let configured = SchemaBuilder::new().with_query_depth_limit(5);
        assert_eq!(configured.query_depth_limit, Some(5));
    }

    #[test]
    fn test_schema_builder_complexity_limit() {
        let configured = SchemaBuilder::new().with_query_complexity_limit(100);
        assert_eq!(configured.query_complexity_limit, Some(100));
    }

    #[test]
    fn test_schema_builder_service_allowlist() {
        let configured = SchemaBuilder::new()
            .with_services(vec!["ServiceA".to_string(), "ServiceB".to_string()]);
        let allowlist = configured
            .service_allowlist
            .expect("with_services should populate the allowlist");
        for service in ["ServiceA", "ServiceB"] {
            assert!(allowlist.contains(service));
        }
        assert_eq!(allowlist.len(), 2);
    }

    #[test]
    fn test_schema_builder_build_empty_fails() {
        let pool = GrpcClientPool::new();
        match SchemaBuilder::new().build(&pool) {
            Ok(_) => panic!("building without descriptor sets should fail"),
            Err(e) => assert!(e.to_string().contains("at least one descriptor set")),
        }
    }

    #[test]
    fn test_schema_builder_with_entity_resolver() {
        let configured = SchemaBuilder::new()
            .with_entity_resolver(Arc::new(GrpcEntityResolver::new(GrpcClientPool::new())));
        assert!(configured.entity_resolver.is_some());
    }

    #[test]
    fn test_schema_builder_invalid_descriptor_bytes() {
        let pool = GrpcClientPool::new();
        let outcome = SchemaBuilder::new()
            .with_descriptor_set_bytes(vec![0xFF, 0xFF, 0xFF])
            .build(&pool);
        assert!(outcome.is_err());
    }
}