use std::collections::HashMap;
use std::collections::HashSet;
use std::sync::Arc;
use std::sync::RwLock;
use apollo_compiler::ast;
use futures::prelude::*;
use reqwest::Client;
use serde::Deserialize;
use serde::Serialize;
use tokio::fs::read_to_string;
use tokio::sync::mpsc;
use tower::BoxError;
use crate::Configuration;
use crate::uplink::UplinkConfig;
use crate::uplink::persisted_queries_manifest_stream::MaybePersistedQueriesManifestChunks;
use crate::uplink::persisted_queries_manifest_stream::PersistedQueriesManifestChunk;
use crate::uplink::persisted_queries_manifest_stream::PersistedQueriesManifestQuery;
use crate::uplink::stream_from_uplink_transforming_new_response;
/// Key identifying one persisted operation: its id plus the client it is registered for.
///
/// Two keys with the same `operation_id` but different `client_name` are distinct entries.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct FullPersistedQueryOperationId {
    /// The persisted query id as it appears in the manifest.
    pub operation_id: String,
    /// Client the operation is registered for; `None` marks a client-agnostic entry.
    pub client_name: Option<String>,
}

/// Mapping from persisted operation key to the GraphQL operation body text.
pub type PersistedQueryManifest = HashMap<FullPersistedQueryOperationId, String>;
/// Verdict for a single freeform (non-persisted) GraphQL request; the flags are acted on
/// by the caller.
pub(crate) struct FreeformGraphQLAction {
    /// Whether the request may proceed.
    pub(crate) should_allow: bool,
    /// Whether the request should be logged as an unknown operation.
    pub(crate) should_log: bool,
}
/// Policy applied to requests that carry a full GraphQL document ("freeform GraphQL")
/// rather than a persisted query id.
#[derive(Debug)]
pub(crate) enum FreeformGraphQLBehavior {
    /// Every freeform request is allowed and none is flagged for logging.
    AllowAll {
        /// Copied from the configuration's `apq.enabled` setting.
        apq_enabled: bool,
    },
    /// Every freeform request is rejected.
    DenyAll {
        /// Whether rejected requests are flagged for logging.
        log_unknown: bool,
    },
    /// Only requests whose normalized body appears in `safelist` are allowed.
    AllowIfInSafelist {
        safelist: FreeformGraphQLSafelist,
        /// Whether rejected requests are flagged for logging.
        log_unknown: bool,
    },
    /// Every request is allowed; those missing from `safelist` are flagged for logging.
    LogUnlessInSafelist {
        safelist: FreeformGraphQLSafelist,
        /// Copied from the configuration's `apq.enabled` setting.
        apq_enabled: bool,
    },
}
impl FreeformGraphQLBehavior {
    /// Decides whether a freeform GraphQL request is allowed and whether it should be logged.
    ///
    /// `ast` is the parsed request document, or the raw body when parsing failed. Only the
    /// safelist-based policies look at it.
    fn action_for_freeform_graphql(
        &self,
        ast: Result<&ast::Document, &str>,
    ) -> FreeformGraphQLAction {
        let (should_allow, should_log) = match self {
            Self::AllowAll { .. } => (true, false),
            Self::DenyAll { log_unknown } => (false, *log_unknown),
            Self::AllowIfInSafelist {
                safelist,
                log_unknown,
            } => {
                let in_safelist = safelist.is_allowed(ast);
                // Safelisted bodies are never logged; rejected ones only when configured to.
                (in_safelist, !in_safelist && *log_unknown)
            }
            Self::LogUnlessInSafelist { safelist, .. } => (true, !safelist.is_allowed(ast)),
        };
        FreeformGraphQLAction {
            should_allow,
            should_log,
        }
    }
}
/// Set of normalized operation bodies that freeform GraphQL requests are matched against.
///
/// Parseable bodies are stored in a canonical printed form. Bodies that fail to parse are
/// stored verbatim.
#[derive(Debug)]
pub(crate) struct FreeformGraphQLSafelist {
    normalized_bodies: HashSet<String>,
}
impl FreeformGraphQLSafelist {
    /// Builds a safelist containing the normalized form of every body in `manifest`.
    fn new(manifest: &PersistedQueryManifest) -> Self {
        let mut safelist = Self {
            normalized_bodies: HashSet::new(),
        };
        for body in manifest.values() {
            safelist.insert_from_manifest(body);
        }
        safelist
    }

    /// Parses `body_from_manifest` and stores its normalized form.
    ///
    /// If it doesn't parse, the raw text is stored instead.
    fn insert_from_manifest(&mut self, body_from_manifest: &str) {
        self.normalized_bodies.insert(
            self.normalize_body(
                ast::Document::parse(body_from_manifest, "from_manifest")
                    .as_ref()
                    .map_err(|_| body_from_manifest),
            ),
        );
    }

    /// Returns whether a request body (parsed document, or raw text on parse failure) is in
    /// the safelist.
    fn is_allowed(&self, ast: Result<&ast::Document, &str>) -> bool {
        match ast {
            // Unparseable bodies are stored verbatim, so probe with the borrowed text
            // instead of allocating an owned copy on every request.
            Err(body_from_request) => self.normalized_bodies.contains(body_from_request),
            Ok(_) => self.normalized_bodies.contains(&self.normalize_body(ast)),
        }
    }

    /// Produces the canonical string used as the safelist key.
    ///
    /// For a parsed document, only operation and fragment definitions are kept. They are
    /// stable-sorted by name (operations first, then fragments) and re-printed, so
    /// whitespace, commas, comments and definition order do not affect the result.
    /// Selection order inside a definition is preserved. A body that failed to parse is
    /// returned unchanged.
    fn normalize_body(&self, ast: Result<&ast::Document, &str>) -> String {
        match ast {
            Err(body_from_request) => body_from_request.to_string(),
            Ok(ast) => {
                let mut operations = vec![];
                let mut fragments = vec![];
                for definition in &ast.definitions {
                    match definition {
                        ast::Definition::OperationDefinition(def) => operations.push(def.clone()),
                        ast::Definition::FragmentDefinition(def) => fragments.push(def.clone()),
                        _ => {}
                    }
                }
                // Stable sorts comparing names in place; same order as sorting by a
                // cloned key, without cloning a key per comparison.
                operations.sort_by(|a, b| a.name.cmp(&b.name));
                fragments.sort_by(|a, b| a.name.cmp(&b.name));
                let mut new_document = ast::Document::new();
                new_document
                    .definitions
                    .extend(operations.into_iter().map(Into::into));
                new_document
                    .definitions
                    .extend(fragments.into_iter().map(Into::into));
                new_document.to_string()
            }
        }
    }
}
/// Value guarded by the poller's lock: the current manifest and the freeform-GraphQL policy
/// derived from it.
///
/// When polling Uplink, the whole value is replaced each time a new manifest is downloaded.
#[derive(Debug)]
pub(crate) struct PersistedQueryManifestPollerState {
    /// Operation bodies keyed by operation id and client name.
    persisted_query_manifest: PersistedQueryManifest,
    /// How requests without a persisted query id are treated.
    pub(crate) freeform_graphql_behavior: FreeformGraphQLBehavior,
}
/// Owns the persisted query manifest state.
///
/// In Uplink mode, this state is shared with a background polling task.
#[derive(Debug)]
pub(crate) struct PersistedQueryManifestPoller {
    /// Lock-protected manifest and freeform-GraphQL policy.
    pub(crate) state: Arc<RwLock<PersistedQueryManifestPollerState>>,
    /// Never sent on. When it is dropped, the channel closes and the polling task
    /// interprets that as a shutdown request.
    _drop_signal: mpsc::Sender<()>,
}
impl PersistedQueryManifestPoller {
pub(crate) async fn new(config: Configuration) -> Result<Self, BoxError> {
if let Some(manifest_files) = config.persisted_queries.local_manifests {
if manifest_files.is_empty() {
return Err("no local persisted query list files specified".into());
}
let mut manifest = PersistedQueryManifest::new();
for local_pq_list in manifest_files {
tracing::info!(
"Loading persisted query list from local file: {}",
local_pq_list
);
let local_manifest: String =
read_to_string(local_pq_list.clone())
.await
.map_err(|e| -> BoxError {
format!(
"could not read local persisted query list file {}: {}",
local_pq_list, e
)
.into()
})?;
let manifest_file: SignedUrlChunk =
serde_json::from_str(&local_manifest).map_err(|e| -> BoxError {
format!(
"could not parse local persisted query list file {}: {}",
local_pq_list.clone(),
e
)
.into()
})?;
if manifest_file.format != "apollo-persisted-query-manifest" {
return Err("chunk format is not 'apollo-persisted-query-manifest'".into());
}
if manifest_file.version != 1 {
return Err("persisted query manifest chunk version is not 1".into());
}
for operation in manifest_file.operations {
manifest.insert(
FullPersistedQueryOperationId {
operation_id: operation.id,
client_name: operation.client_name,
},
operation.body,
);
}
}
let freeform_graphql_behavior = if config.persisted_queries.safelist.enabled {
if config.persisted_queries.safelist.require_id {
FreeformGraphQLBehavior::DenyAll {
log_unknown: config.persisted_queries.log_unknown,
}
} else {
FreeformGraphQLBehavior::AllowIfInSafelist {
safelist: FreeformGraphQLSafelist::new(&manifest),
log_unknown: config.persisted_queries.log_unknown,
}
}
} else if config.persisted_queries.log_unknown {
FreeformGraphQLBehavior::LogUnlessInSafelist {
safelist: FreeformGraphQLSafelist::new(&manifest),
apq_enabled: config.apq.enabled,
}
} else {
FreeformGraphQLBehavior::AllowAll {
apq_enabled: config.apq.enabled,
}
};
let state = Arc::new(RwLock::new(PersistedQueryManifestPollerState {
persisted_query_manifest: manifest.clone(),
freeform_graphql_behavior,
}));
tracing::info!(
"Loaded {} persisted queries from local file.",
manifest.len()
);
Ok(Self {
state,
_drop_signal: mpsc::channel::<()>(1).0,
})
} else if let Some(uplink_config) = config.uplink.as_ref() {
let state = Arc::new(RwLock::new(PersistedQueryManifestPollerState {
persisted_query_manifest: PersistedQueryManifest::new(),
freeform_graphql_behavior: FreeformGraphQLBehavior::DenyAll { log_unknown: false },
}));
let http_client = Client::builder().timeout(uplink_config.timeout).gzip(true).build()
.map_err(|e| -> BoxError {
format!(
"could not initialize HTTP client for fetching persisted queries manifest chunks: {}",
e
).into()
})?;
let (_drop_signal, drop_receiver) = mpsc::channel::<()>(1);
let (ready_sender, mut ready_receiver) =
mpsc::channel::<ManifestPollResultOnStartup>(1);
tokio::task::spawn(poll_uplink(
uplink_config.clone(),
state.clone(),
config,
ready_sender,
drop_receiver,
http_client,
));
match ready_receiver.recv().await {
Some(startup_result) => match startup_result {
ManifestPollResultOnStartup::LoadedOperations => (),
ManifestPollResultOnStartup::Err(error) => return Err(error),
},
None => {
return Err("could not receive ready event for persisted query layer".into());
}
}
Ok(Self {
state,
_drop_signal,
})
} else {
Err("persisted queries requires Apollo GraphOS. ensure that you have set APOLLO_KEY and APOLLO_GRAPH_REF environment variables".into())
}
}
pub(crate) fn get_operation_body(
&self,
persisted_query_id: &str,
client_name: Option<String>,
) -> Option<String> {
let state = self
.state
.read()
.expect("could not acquire read lock on persisted query manifest state");
if let Some(body) = state
.persisted_query_manifest
.get(&FullPersistedQueryOperationId {
operation_id: persisted_query_id.to_string(),
client_name: client_name.clone(),
})
.cloned()
{
Some(body)
} else if client_name.is_some() {
state
.persisted_query_manifest
.get(&FullPersistedQueryOperationId {
operation_id: persisted_query_id.to_string(),
client_name: None,
})
.cloned()
} else {
None
}
}
pub(crate) fn get_all_operations(&self) -> Vec<String> {
let state = self
.state
.read()
.expect("could not acquire read lock on persisted query manifest state");
state.persisted_query_manifest.values().cloned().collect()
}
pub(crate) fn action_for_freeform_graphql(
&self,
ast: Result<&ast::Document, &str>,
) -> FreeformGraphQLAction {
let state = self
.state
.read()
.expect("could not acquire read lock on persisted query state");
state
.freeform_graphql_behavior
.action_for_freeform_graphql(ast)
}
pub(crate) fn never_allows_freeform_graphql(&self) -> Option<bool> {
let state = self
.state
.read()
.expect("could not acquire read lock on persisted query state");
if let FreeformGraphQLBehavior::DenyAll { log_unknown } = state.freeform_graphql_behavior {
Some(log_unknown)
} else {
None
}
}
pub(crate) fn augmenting_apq_with_pre_registration_and_no_safelisting(&self) -> bool {
let state = self
.state
.read()
.expect("could not acquire read lock on persisted query state");
match state.freeform_graphql_behavior {
FreeformGraphQLBehavior::AllowAll { apq_enabled, .. }
| FreeformGraphQLBehavior::LogUnlessInSafelist { apq_enabled, .. } => apq_enabled,
_ => false,
}
}
}
/// Background task that keeps `state` in sync with the persisted query list served by Uplink.
///
/// Two streams are merged:
/// - Uplink responses, whose chunks are downloaded into a fresh manifest;
/// - the drop channel, which yields `Shutdown` once the poller's sender is dropped.
///
/// Each new manifest replaces `state` wholesale, together with the freeform-GraphQL policy
/// derived from `config`. The first outcome (success or error) is reported once through
/// `ready_sender`. Later errors are logged.
async fn poll_uplink(
    uplink_config: UplinkConfig,
    state: Arc<RwLock<PersistedQueryManifestPollerState>>,
    config: Configuration,
    ready_sender: mpsc::Sender<ManifestPollResultOnStartup>,
    mut drop_receiver: mpsc::Receiver<()>,
    http_client: Client,
) {
    let mut uplink_executor = stream::select_all(vec![
        stream_from_uplink_transforming_new_response::<
            PersistedQueriesManifestQuery,
            MaybePersistedQueriesManifestChunks,
            Option<PersistedQueryManifest>,
        >(uplink_config.clone(), move |response| {
            // The transform may run once per response; each future gets its own client handle.
            let http_client = http_client.clone();
            Box::new(Box::pin(async move {
                match response {
                    Some(chunks) => manifest_from_chunks(chunks, http_client)
                        .await
                        .map(Some)
                        .map_err(|err| {
                            format!("could not download persisted query lists: {}", err).into()
                        }),
                    None => Ok(None),
                }
            }))
        })
        .map(|res| match res {
            Ok(Some(new_manifest)) => ManifestPollEvent::NewManifest(new_manifest),
            Ok(None) => ManifestPollEvent::NoPersistedQueryList {
                graph_ref: uplink_config.apollo_graph_ref.clone(),
            },
            Err(e) => ManifestPollEvent::Err(e.into()),
        })
        .boxed(),
        drop_receiver
            .recv()
            .into_stream()
            .filter_map(|res| {
                future::ready(match res {
                    None => Some(ManifestPollEvent::Shutdown),
                    Some(()) => Some(ManifestPollEvent::Err(
                        "received message on drop channel in persisted query layer, which never \
                        gets sent"
                            .into(),
                    )),
                })
            })
            .boxed(),
    ])
    // Stop the whole merged stream as soon as shutdown is signalled.
    .take_while(|msg| future::ready(!matches!(msg, ManifestPollEvent::Shutdown)))
    .boxed();
    let mut ready_sender_once = Some(ready_sender);
    while let Some(event) = uplink_executor.next().await {
        match event {
            ManifestPollEvent::NewManifest(new_manifest) => {
                let freeform_graphql_behavior = if config.persisted_queries.safelist.enabled {
                    if config.persisted_queries.safelist.require_id {
                        FreeformGraphQLBehavior::DenyAll {
                            log_unknown: config.persisted_queries.log_unknown,
                        }
                    } else {
                        FreeformGraphQLBehavior::AllowIfInSafelist {
                            safelist: FreeformGraphQLSafelist::new(&new_manifest),
                            log_unknown: config.persisted_queries.log_unknown,
                        }
                    }
                } else if config.persisted_queries.log_unknown {
                    FreeformGraphQLBehavior::LogUnlessInSafelist {
                        safelist: FreeformGraphQLSafelist::new(&new_manifest),
                        apq_enabled: config.apq.enabled,
                    }
                } else {
                    FreeformGraphQLBehavior::AllowAll {
                        apq_enabled: config.apq.enabled,
                    }
                };
                let new_state = PersistedQueryManifestPollerState {
                    persisted_query_manifest: new_manifest,
                    freeform_graphql_behavior,
                };
                // The write guard is a temporary dropped at the end of this statement, so it
                // is never held across the `.await` below.
                *state
                    .write()
                    .expect("could not acquire write lock on persisted query manifest state") =
                    new_state;
                send_startup_event_or_log_error(
                    &mut ready_sender_once,
                    ManifestPollResultOnStartup::LoadedOperations,
                )
                .await;
            }
            ManifestPollEvent::Err(e) => {
                send_startup_event_or_log_error(
                    &mut ready_sender_once,
                    ManifestPollResultOnStartup::Err(e),
                )
                .await
            }
            ManifestPollEvent::NoPersistedQueryList { graph_ref } => {
                send_startup_event_or_log_error(
                    &mut ready_sender_once,
                    ManifestPollResultOnStartup::Err(
                        format!("no persisted query list found for graph ref {}", &graph_ref)
                            .into(),
                    ),
                )
                .await
            }
            // Filtered out by `take_while` above.
            ManifestPollEvent::Shutdown => (),
        }
    }

    /// Sends the first event to the startup waiter. After that, errors are logged and
    /// successes are ignored.
    async fn send_startup_event_or_log_error(
        ready_sender: &mut Option<mpsc::Sender<ManifestPollResultOnStartup>>,
        message: ManifestPollResultOnStartup,
    ) {
        match (ready_sender.take(), message) {
            (Some(ready_sender), message) => {
                if let Err(e) = ready_sender.send(message).await {
                    tracing::debug!(
                        "could not send startup event for the persisted query layer: {e}"
                    );
                }
            }
            (None, ManifestPollResultOnStartup::Err(err)) => {
                tracing::error!(
                    "error while polling uplink for persisted query manifests: {}",
                    err
                )
            }
            (None, ManifestPollResultOnStartup::LoadedOperations) => {}
        }
    }
}
/// Downloads every chunk Uplink advertised and merges their operations into one manifest.
///
/// Chunks are fetched one after another in the given order. Entries with the same key in a
/// later chunk replace earlier ones.
///
/// # Errors
///
/// Stops at the first chunk that cannot be fetched and returns its error; the partially
/// built manifest is discarded.
async fn manifest_from_chunks(
    new_chunks: Vec<PersistedQueriesManifestChunk>,
    http_client: Client,
) -> Result<PersistedQueryManifest, BoxError> {
    tracing::debug!("ingesting new persisted queries: {:?}", &new_chunks);
    let mut merged = PersistedQueryManifest::new();
    for chunk in new_chunks {
        add_chunk_to_operations(chunk, &mut merged, http_client.clone()).await?;
    }
    tracing::info!("Loaded {} persisted queries.", merged.len());
    Ok(merged)
}
/// Fetches one manifest chunk and inserts its operations into `operations`.
///
/// The chunk's alternative URLs are tried in order. The first URL that yields a valid chunk
/// is used and the rest are not contacted. A failure on any URL but the last is logged at
/// debug level and the next URL is tried.
///
/// # Errors
///
/// Returns the last URL's error when every URL fails, or an error when the chunk lists no
/// URLs at all.
async fn add_chunk_to_operations(
    chunk: PersistedQueriesManifestChunk,
    operations: &mut PersistedQueryManifest,
    http_client: Client,
) -> Result<(), BoxError> {
    let mut candidate_urls = chunk.urls.iter().peekable();
    while let Some(chunk_url) = candidate_urls.next() {
        let fetched = fetch_chunk(http_client.clone(), chunk_url).await;
        let is_last_candidate = candidate_urls.peek().is_none();
        match fetched {
            Ok(signed_chunk) => {
                operations.extend(signed_chunk.operations.into_iter().map(|operation| {
                    (
                        FullPersistedQueryOperationId {
                            operation_id: operation.id,
                            client_name: operation.client_name,
                        },
                        operation.body,
                    )
                }));
                return Ok(());
            }
            Err(e) if is_last_candidate => return Err(e),
            Err(e) => {
                tracing::debug!(
                    "failed to fetch persisted query list chunk from {}: {}. \
                    Other endpoints will be tried",
                    chunk_url,
                    e
                );
            }
        }
    }
    Err("persisted query chunk did not include any URLs to fetch operations from".into())
}
/// Downloads and validates a single persisted query manifest chunk from `chunk_url`.
///
/// Takes `&str` so callers holding a `&String` still work through deref coercion.
///
/// # Errors
///
/// Returns an error if any of these happen:
/// - the request fails or returns a non-success status;
/// - the body does not deserialize as a [`SignedUrlChunk`];
/// - the chunk's `format` is not `apollo-persisted-query-manifest`;
/// - the chunk's `version` is not `1`.
async fn fetch_chunk(http_client: Client, chunk_url: &str) -> Result<SignedUrlChunk, BoxError> {
    let chunk = http_client
        .get(chunk_url)
        .send()
        .await
        .and_then(|r| r.error_for_status())
        .map_err(|e| -> BoxError {
            format!(
                "error fetching persisted queries manifest chunk from {}: {}",
                chunk_url, e
            )
            .into()
        })?
        .json::<SignedUrlChunk>()
        .await
        .map_err(|e| -> BoxError {
            format!(
                "error reading body of persisted queries manifest chunk from {}: {}",
                chunk_url, e
            )
            .into()
        })?;
    if chunk.format != "apollo-persisted-query-manifest" {
        return Err("chunk format is not 'apollo-persisted-query-manifest'".into());
    }
    if chunk.version != 1 {
        return Err("persisted query manifest chunk version is not 1".into());
    }
    Ok(chunk)
}
/// Events yielded by the merged Uplink/shutdown stream inside `poll_uplink`.
#[derive(Debug)]
pub(crate) enum ManifestPollEvent {
    /// A freshly downloaded manifest that replaces the current one.
    NewManifest(PersistedQueryManifest),
    /// Uplink returned no persisted query list for `graph_ref`.
    NoPersistedQueryList { graph_ref: String },
    /// Polling Uplink or downloading chunks failed.
    Err(BoxError),
    /// The poller was dropped; polling stops.
    Shutdown,
}
/// First outcome of Uplink polling, reported once to the waiting constructor.
#[derive(Debug)]
pub(crate) enum ManifestPollResultOnStartup {
    /// A manifest was downloaded and installed into the shared state.
    LoadedOperations,
    /// The first poll failed; the constructor returns this error.
    Err(BoxError),
}
/// One persisted query manifest document, either downloaded from a chunk URL or read from a
/// local manifest file.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub(crate) struct SignedUrlChunk {
    /// Checked after parsing; must equal `"apollo-persisted-query-manifest"`.
    pub(crate) format: String,
    /// Checked after parsing; must equal `1`.
    pub(crate) version: u64,
    /// Operations listed in this document.
    pub(crate) operations: Vec<Operation>,
}
/// A single persisted operation entry.
///
/// Field names are (de)serialized in camelCase, so `client_name` maps to `clientName`.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub(crate) struct Operation {
    /// Persisted query id.
    pub(crate) id: String,
    /// GraphQL operation text.
    pub(crate) body: String,
    /// Client the entry is registered for; `None` when the entry does not name a client.
    pub(crate) client_name: Option<String>,
}
// Tests for the manifest poller: Uplink-backed loading, failover between Uplink endpoints,
// safelist normalization and local manifest files.
#[cfg(test)]
mod tests {
    use url::Url;
    use super::*;
    use crate::configuration::Apq;
    use crate::configuration::PersistedQueries;
    use crate::test_harness::mocks::persisted_queries::*;
    use crate::uplink::Endpoints;

    // Loads a fake manifest through a mocked Uplink and checks that its operation id
    // resolves to its body when no client name is given.
    #[tokio::test(flavor = "multi_thread")]
    async fn poller_can_get_operation_bodies() {
        let (id, body, manifest) = fake_manifest();
        let (_mock_guard, uplink_config) = mock_pq_uplink(&manifest).await;
        let manifest_manager = PersistedQueryManifestPoller::new(
            Configuration::fake_builder()
                .uplink(uplink_config)
                .build()
                .unwrap(),
        )
        .await
        .unwrap();
        assert_eq!(manifest_manager.get_operation_body(&id, None), Some(body))
    }

    // `new` waits for the first poll result, so an unreachable Uplink endpoint must make
    // construction fail rather than hang or succeed with an empty manifest.
    #[tokio::test(flavor = "multi_thread")]
    async fn poller_wont_start_without_uplink_connection() {
        let uplink_endpoint = Url::parse("https://definitely.not.uplink").unwrap();
        assert!(
            PersistedQueryManifestPoller::new(
                Configuration::fake_builder()
                    .uplink(UplinkConfig::for_tests(Endpoints::fallback(vec![
                        uplink_endpoint
                    ])))
                    .build()
                    .unwrap(),
            )
            .await
            .is_err()
        );
    }

    // The first Uplink endpoint is a mock whose chunk download presumably fails
    // (`mock_pq_uplink_bad_gcs` — NOTE(review): confirm against the mock). With fallback
    // endpoints, the poller is expected to move on to the second, healthy endpoint.
    #[tokio::test(flavor = "multi_thread")]
    async fn poller_fails_over_on_gcs_failure() {
        let (_mock_server1, url1) = mock_pq_uplink_bad_gcs().await;
        let (id, body, manifest) = fake_manifest();
        let (_mock_guard2, url2) = mock_pq_uplink_one_endpoint(&manifest, None).await;
        let manifest_manager = PersistedQueryManifestPoller::new(
            Configuration::fake_builder()
                .uplink(UplinkConfig::for_tests(Endpoints::fallback(vec![
                    url1, url2,
                ])))
                .build()
                .unwrap(),
        )
        .await
        .unwrap();
        assert_eq!(manifest_manager.get_operation_body(&id, None), Some(body))
    }

    // Safelist matching ignores whitespace, commas, comments and the order of top-level
    // definitions. It does not ignore selection order inside a definition.
    // Unparseable bodies only match byte-for-byte.
    #[test]
    fn safelist_body_normalization() {
        let safelist = FreeformGraphQLSafelist::new(&PersistedQueryManifest::from([
            (
                FullPersistedQueryOperationId {
                    operation_id: "valid-syntax".to_string(),
                    client_name: None,
                },
                "fragment A on T { a } query SomeOp { ...A ...B } fragment,,, B on U{b c } # yeah".to_string(),
            ),
            (
                FullPersistedQueryOperationId {
                    operation_id: "invalid-syntax".to_string(),
                    client_name: None,
                },
                "}}}".to_string(),
            ),
        ]));
        // Mirrors how request bodies reach the safelist: parsed document or raw text.
        let is_allowed = |body: &str| -> bool {
            safelist.is_allowed(ast::Document::parse(body, "").as_ref().map_err(|_| body))
        };
        // Exact manifest text.
        assert!(is_allowed(
            "fragment A on T { a } query SomeOp { ...A ...B } fragment,,, B on U{b c } # yeah"
        ));
        // Same definitions, reordered, with different whitespace, commas and comments.
        assert!(is_allowed(
            "#comment\n fragment, B on U , { b c } query SomeOp { ...A ...B } fragment \nA on T { a }"
        ));
        // Field order inside a selection set is significant.
        assert!(!is_allowed(
            "fragment A on T { a } query SomeOp { ...A ...B } fragment,,, B on U{c b } # yeah"
        ));
        // Invalid syntax is compared verbatim.
        assert!(!is_allowed("}}}}"));
        assert!(is_allowed("}}}"));
    }

    // Loads operations from a local fixture file instead of Uplink.
    // NOTE(review): assumes the fixture contains id "5678" with the same body that
    // `fake_manifest()` returns — confirm against the fixture file.
    #[tokio::test(flavor = "multi_thread")]
    async fn uses_local_manifest() {
        let (_, body, _) = fake_manifest();
        let id = "5678".to_string();
        let manifest_manager = PersistedQueryManifestPoller::new(
            Configuration::fake_builder()
                .apq(Apq::fake_new(Some(false)))
                .persisted_query(
                    PersistedQueries::builder()
                        .enabled(true)
                        .local_manifests(vec![
                            "tests/fixtures/persisted-queries-manifest.json".to_string(),
                        ])
                        .build(),
                )
                .build()
                .unwrap(),
        )
        .await
        .unwrap();
        assert_eq!(manifest_manager.get_operation_body(&id, None), Some(body))
    }
}