use std::collections::HashMap;
use std::collections::HashSet;
use std::sync::Arc;
use std::task::Context;
use std::task::Poll;
use access_json::JSONQuery;
use http::HeaderValue;
use http::header::ACCEPT;
use http::header::ACCEPT_ENCODING;
use http::header::CONNECTION;
use http::header::CONTENT_ENCODING;
use http::header::CONTENT_LENGTH;
use http::header::CONTENT_TYPE;
use http::header::HOST;
use http::header::HeaderName;
use http::header::PROXY_AUTHENTICATE;
use http::header::PROXY_AUTHORIZATION;
use http::header::TE;
use http::header::TRAILER;
use http::header::TRANSFER_ENCODING;
use http::header::UPGRADE;
use regex::Regex;
use schemars::JsonSchema;
use serde::Deserialize;
use serde_json::Value;
use tower::BoxError;
use tower::Layer;
use tower::ServiceBuilder;
use tower::ServiceExt;
use tower_service::Service;
use crate::plugin::Plugin;
use crate::plugin::PluginInit;
use crate::plugin::serde::deserialize_header_name;
use crate::plugin::serde::deserialize_header_value;
use crate::plugin::serde::deserialize_json_query;
use crate::plugin::serde::deserialize_option_header_name;
use crate::plugin::serde::deserialize_option_header_value;
use crate::plugin::serde::deserialize_regex;
use crate::register_plugin;
use crate::services::SubgraphRequest;
use crate::services::subgraph;
// Registers `Headers` as the "headers" plugin in the "apollo" group via the crate's
// `register_plugin!` macro.
register_plugin!("apollo", "headers", Headers);
// Header rules for one location: either every subgraph (`Config::all`) or a single
// subgraph (`Config::subgraphs`).
// NOTE: plain `//` comments are used on config types on purpose; `///` doc comments would
// be picked up by `schemars` and alter the generated configuration schema.
#[derive(Clone, JsonSchema, Deserialize)]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
struct HeadersLocation {
    // Operations applied, in order, to the outgoing subgraph request's headers.
    request: Vec<Operation>,
}
// A single header rule. Externally tagged in configuration, i.e. written as
// `- insert: {...}`, `- remove: {...}` or `- propagate: {...}`.
#[derive(Clone, JsonSchema, Deserialize)]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
enum Operation {
    // Set a header on the subgraph request (replacing any existing values).
    Insert(Insert),
    // Delete headers from the subgraph request.
    Remove(Remove),
    // Copy headers from the client (supergraph) request to the subgraph request.
    Propagate(Propagate),
}
// JSON-schema helpers for the `Remove` variants. Their payloads (`HeaderName`, `Regex`)
// have no schema of their own, so `#[schemars(schema_with = ...)]` points at these
// generated functions; presumably each describes a `String` with the given description.
schemar_fn!(remove_named, String, "Remove a header given a header name");
schemar_fn!(
    remove_matching,
    String,
    "Remove a header given a regex matching against the header name"
);
// How to pick the headers to remove. Externally tagged: `named: <header>` or
// `matching: <regex>`.
#[derive(Clone, JsonSchema, Deserialize)]
#[serde(rename_all = "snake_case")]
enum Remove {
    // Remove exactly this header. This path does not consult the reserved-header list.
    #[schemars(schema_with = "remove_named")]
    #[serde(deserialize_with = "deserialize_header_name")]
    Named(HeaderName),
    // Remove every header whose name matches this regex (`is_match`, so unanchored),
    // except reserved headers, which are always kept.
    #[schemars(schema_with = "remove_matching")]
    #[serde(deserialize_with = "deserialize_regex")]
    Matching(Regex),
}
// Where an inserted header value comes from. Untagged: serde tries the variants in
// declaration order and picks the first whose fields match (`value` -> static,
// `from_context` -> context, `path` -> request body).
#[derive(Clone, JsonSchema, Deserialize)]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
#[serde(untagged)]
enum Insert {
    Static(InsertStatic),
    FromContext(InsertFromContext),
    FromBody(InsertFromBody),
}
// Insert a header with a fixed value taken from configuration.
#[derive(Clone, JsonSchema, Deserialize)]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
struct InsertStatic {
    // Target header name.
    #[schemars(with = "String")]
    #[serde(deserialize_with = "deserialize_header_name")]
    name: HeaderName,
    // Header value to set.
    #[schemars(with = "String")]
    #[serde(deserialize_with = "deserialize_header_value")]
    value: HeaderValue,
}
// Insert a header whose value is read, as a string, from the request context.
#[derive(Clone, JsonSchema, Deserialize)]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
struct InsertFromContext {
    // Target header name.
    #[schemars(with = "String")]
    #[serde(deserialize_with = "deserialize_header_name")]
    name: HeaderName,
    // Context key to read the value from.
    from_context: String,
}
// Insert a header whose value is extracted from the supergraph request body.
// (Plain `//` comments: `///` doc comments would change the generated schema.)
#[derive(Clone, JsonSchema, Deserialize)]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
struct InsertFromBody {
    // Target header name.
    #[schemars(with = "String")]
    #[serde(deserialize_with = "deserialize_header_name")]
    name: HeaderName,
    // JSON query evaluated against the supergraph request body.
    #[schemars(with = "String")]
    #[serde(deserialize_with = "deserialize_json_query")]
    path: JSONQuery,
    // Value inserted when `path` does not resolve. Optional: serde only treats a missing
    // `Option` field as `None` automatically when no `deserialize_with` is set, so
    // `default` is needed here to agree with the schema, which marks the field optional.
    #[schemars(with = "Option<String>", default)]
    #[serde(deserialize_with = "deserialize_option_header_value", default)]
    default: Option<HeaderValue>,
}
// JSON-schema helper for `Propagate::Matching`'s `matching` field (a regex that has no
// schema of its own).
schemar_fn!(
    propagate_matching,
    String,
    "Propagate a header given a regex matching header name"
);
// How to copy headers from the client (supergraph) request. Untagged: a `named` field
// selects `Named`, a `matching` field selects `Matching`.
#[derive(Clone, JsonSchema, Deserialize)]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
#[serde(untagged)]
enum Propagate {
    // Copy every value of one header, optionally under a new name and with a fallback
    // value when the client did not send it. Does not consult the reserved-header list.
    Named {
        #[schemars(with = "String")]
        #[serde(deserialize_with = "deserialize_header_name")]
        named: HeaderName,
        #[schemars(with = "Option<String>", default)]
        #[serde(deserialize_with = "deserialize_option_header_name", default)]
        rename: Option<HeaderName>,
        #[schemars(with = "Option<String>", default)]
        #[serde(deserialize_with = "deserialize_option_header_value", default)]
        default: Option<HeaderValue>,
    },
    // Copy every non-reserved header whose name matches this regex (`is_match`).
    Matching {
        #[schemars(schema_with = "propagate_matching")]
        #[serde(deserialize_with = "deserialize_regex")]
        matching: Regex,
    },
}
// Plugin configuration. Both sections are optional (`serde(default)`).
#[derive(Clone, JsonSchema, Default, Deserialize)]
#[serde(rename_all = "snake_case", deny_unknown_fields, default)]
struct Config {
    // Rules applied to every subgraph.
    all: Option<HeadersLocation>,
    // Extra rules per subgraph name; applied after the `all` rules for that subgraph.
    subgraphs: HashMap<String, HeadersLocation>,
}
/// The headers plugin: holds the pre-computed rule list for each subgraph.
struct Headers {
    /// Rules from `all`, used for subgraphs without a dedicated entry.
    all_operations: Arc<Vec<Operation>>,
    /// Per-subgraph rules: the `all` rules followed by the subgraph's own rules.
    subgraph_operations: HashMap<String, Arc<Vec<Operation>>>,
    /// Header names that regex-based remove/propagate rules never touch.
    reserved_headers: Arc<HashSet<&'static HeaderName>>,
}
#[async_trait::async_trait]
impl Plugin for Headers {
type Config = Config;
async fn new(init: PluginInit<Self::Config>) -> Result<Self, BoxError> {
let operations: Vec<Operation> = init
.config
.all
.as_ref()
.map(|a| a.request.clone())
.unwrap_or_default();
let subgraph_operations = init
.config
.subgraphs
.iter()
.map(|(subgraph_name, op)| {
let mut operations = operations.clone();
operations.append(&mut op.request.clone());
(subgraph_name.clone(), Arc::new(operations))
})
.collect();
Ok(Headers {
all_operations: Arc::new(operations),
subgraph_operations,
reserved_headers: Arc::new(RESERVED_HEADERS.iter().collect()),
})
}
fn subgraph_service(&self, name: &str, service: subgraph::BoxService) -> subgraph::BoxService {
ServiceBuilder::new()
.layer(HeadersLayer::new(
self.subgraph_operations
.get(name)
.cloned()
.unwrap_or_else(|| self.all_operations.clone()),
self.reserved_headers.clone(),
))
.service(service)
.boxed()
}
}
/// Tower layer that wraps a subgraph service in a `HeadersService`.
struct HeadersLayer {
    /// Rules to apply, in order.
    operations: Arc<Vec<Operation>>,
    /// Header names protected from regex-based remove/propagate rules.
    reserved_headers: Arc<HashSet<&'static HeaderName>>,
}
impl HeadersLayer {
fn new(
operations: Arc<Vec<Operation>>,
reserved_headers: Arc<HashSet<&'static HeaderName>>,
) -> Self {
Self {
operations,
reserved_headers,
}
}
}
impl<S> Layer<S> for HeadersLayer {
type Service = HeadersService<S>;
fn layer(&self, inner: S) -> Self::Service {
HeadersService {
inner,
operations: self.operations.clone(),
reserved_headers: self.reserved_headers.clone(),
}
}
}
/// Subgraph service wrapper that rewrites request headers before delegating to `inner`.
struct HeadersService<S> {
    /// The wrapped subgraph service.
    inner: S,
    /// Rules to apply, in order.
    operations: Arc<Vec<Operation>>,
    /// Header names protected from regex-based remove/propagate rules.
    reserved_headers: Arc<HashSet<&'static HeaderName>>,
}
// Headers that regex-based rules (`remove: matching`, `propagate: matching`) never
// touch. The list covers two groups:
// - the HTTP/1.1 hop-by-hop headers (connection, keep-alive, proxy-authenticate,
//   proxy-authorization, te, trailer, transfer-encoding, upgrade);
// - headers tied to the subgraph request's own body/transport (content-*, host,
//   accept, accept-encoding), presumably set by the router itself — confirm.
// Exact-name `remove`/`propagate` rules do not consult this list.
static RESERVED_HEADERS: [HeaderName; 14] = [
    CONNECTION,
    PROXY_AUTHENTICATE,
    PROXY_AUTHORIZATION,
    TE,
    TRAILER,
    TRANSFER_ENCODING,
    UPGRADE,
    CONTENT_LENGTH,
    CONTENT_TYPE,
    CONTENT_ENCODING,
    HOST,
    ACCEPT,
    ACCEPT_ENCODING,
    HeaderName::from_static("keep-alive"),
];
// Tower service adapter: rewrites the subgraph request's headers, then hands the request
// to the wrapped service. Readiness and the response future come straight from `inner`.
impl<S> Service<SubgraphRequest> for HeadersService<S>
where
    S: Service<SubgraphRequest>,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = S::Future;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, request: SubgraphRequest) -> Self::Future {
        let mut request = request;
        self.modify_request(&mut request);
        self.inner.call(request)
    }
}
impl<S> HeadersService<S> {
fn modify_request(&self, req: &mut SubgraphRequest) {
let mut already_propagated: HashSet<&str> = HashSet::new();
for operation in &*self.operations {
match operation {
Operation::Insert(insert_config) => match insert_config {
Insert::Static(static_insert) => {
req.subgraph_request
.headers_mut()
.insert(&static_insert.name, static_insert.value.clone());
}
Insert::FromContext(insert_from_context) => {
if let Some(val) = req
.context
.get::<_, String>(&insert_from_context.from_context)
.ok()
.flatten()
{
match HeaderValue::from_str(&val) {
Ok(header_value) => {
req.subgraph_request
.headers_mut()
.insert(&insert_from_context.name, header_value);
}
Err(err) => {
tracing::error!(
"cannot convert from the context into a header value for header name '{}': {:?}",
insert_from_context.name,
err
);
}
}
}
}
Insert::FromBody(from_body) => {
let output = from_body
.path
.execute(req.supergraph_request.body())
.ok()
.flatten();
if let Some(val) = output {
let header_value = if let Value::String(val_str) = val {
val_str
} else {
val.to_string()
};
match HeaderValue::from_str(&header_value) {
Ok(header_value) => {
req.subgraph_request
.headers_mut()
.insert(&from_body.name, header_value);
}
Err(err) => {
tracing::error!(
"cannot convert from the body into a header value for header name '{}': {:?}",
from_body.name,
err
);
}
}
} else if let Some(default_val) = &from_body.default {
req.subgraph_request
.headers_mut()
.insert(&from_body.name, default_val.clone());
}
}
},
Operation::Remove(Remove::Named(name)) => {
req.subgraph_request.headers_mut().remove(name);
}
Operation::Remove(Remove::Matching(matching)) => {
let headers = req.subgraph_request.headers_mut();
let new_headers = headers
.drain()
.filter_map(|(name, value)| {
name.and_then(|name| {
(self.reserved_headers.contains(&name)
|| !matching.is_match(name.as_str()))
.then_some((name, value))
})
})
.collect();
let _ = std::mem::replace(headers, new_headers);
}
Operation::Propagate(Propagate::Named {
named,
rename,
default,
}) => {
let target_header = rename.as_ref().unwrap_or(named);
if !already_propagated.contains(target_header.as_str()) {
let headers = req.subgraph_request.headers_mut();
let values = req.supergraph_request.headers().get_all(named);
if values.iter().count() == 0 {
if let Some(default) = default {
headers.append(target_header, default.clone());
already_propagated.insert(target_header.as_str());
}
} else {
for value in values {
headers.append(target_header, value.clone());
already_propagated.insert(target_header.as_str());
}
}
}
}
Operation::Propagate(Propagate::Matching { matching }) => {
let mut previous_name = None;
let headers = req.subgraph_request.headers_mut();
req.supergraph_request
.headers()
.iter()
.filter(|(name, _)| {
!self.reserved_headers.contains(*name)
&& matching.is_match(name.as_str())
})
.for_each(|(name, value)| {
if !already_propagated.contains(name.as_str()) {
headers.append(name, value.clone());
match previous_name {
None => previous_name = Some(name),
Some(previous) => {
if previous != name {
already_propagated.insert(previous.as_str());
previous_name = Some(name);
}
}
}
}
});
if let Some(name) = previous_name {
already_propagated.insert(name.as_str());
}
}
}
}
}
}
// Unit tests for the headers plugin, in three groups:
// - configuration parsing;
// - each header operation applied directly through `HeadersLayer` /
//   `HeadersService::modify_request`;
// - end-to-end checks driven by YAML fixtures through `PluginTestHarness`.
#[cfg(test)]
mod test {
use std::collections::HashSet;
use std::str::FromStr;
use std::sync::Arc;
use subgraph::SubgraphRequestId;
use tower::BoxError;
use super::*;
use crate::Context;
use crate::graphql;
use crate::graphql::Request;
use crate::plugin::test::MockSubgraphService;
use crate::plugins::test::PluginTestHarness;
use crate::query_planner::fetch::OperationKind;
use crate::services::SubgraphRequest;
use crate::services::SubgraphResponse;
// A per-subgraph `insert` rule deserializes.
#[test]
fn test_subgraph_config() {
serde_yaml::from_str::<Config>(
r#"
subgraphs:
products:
request:
- insert:
name: "test"
value: "test"
"#,
)
.unwrap();
}
// A global (`all`) `insert` rule deserializes.
#[test]
fn test_insert_config() {
serde_yaml::from_str::<Config>(
r#"
all:
request:
- insert:
name: "test"
value: "test"
"#,
)
.unwrap();
}
// `remove` deserializes with `named` and with `matching`; an invalid regex is rejected.
#[test]
fn test_remove_config() {
serde_yaml::from_str::<Config>(
r#"
all:
request:
- remove:
named: "test"
"#,
)
.unwrap();
serde_yaml::from_str::<Config>(
r#"
all:
request:
- remove:
matching: "d.*"
"#,
)
.unwrap();
assert!(
serde_yaml::from_str::<Config>(
r#"
all:
request:
- remove:
matching: "d.*["
"#,
)
.is_err()
);
}
// `propagate` deserializes in its `named` form (optionally with `rename` and `default`)
// and in its `matching` form.
#[test]
fn test_propagate_config() {
serde_yaml::from_str::<Config>(
r#"
all:
request:
- propagate:
named: "test"
"#,
)
.unwrap();
serde_yaml::from_str::<Config>(
r#"
all:
request:
- propagate:
named: "test"
rename: "bif"
"#,
)
.unwrap();
serde_yaml::from_str::<Config>(
r#"
all:
request:
- propagate:
named: "test"
rename: "bif"
default: "bof"
"#,
)
.unwrap();
serde_yaml::from_str::<Config>(
r#"
all:
request:
- propagate:
matching: "d.*"
"#,
)
.unwrap();
}
// A static insert adds `c: d` alongside the existing subgraph headers.
#[tokio::test]
async fn test_insert_static() -> Result<(), BoxError> {
let mut mock = MockSubgraphService::new();
mock.expect_call()
.times(1)
.withf(|request| {
request.assert_headers(vec![
("aa", "vaa"),
("ab", "vab"),
("ac", "vac"),
("c", "d"),
])
})
.returning(example_response);
let mut service = HeadersLayer::new(
Arc::new(vec![Operation::Insert(Insert::Static(InsertStatic {
name: "c".try_into()?,
value: "d".try_into()?,
}))]),
Arc::new(RESERVED_HEADERS.iter().collect()),
)
.layer(mock);
service.ready().await?.call(example_request()).await?;
Ok(())
}
// The inserted value is read from context key `my_key`, which `example_request` sets.
#[tokio::test]
async fn test_insert_from_context() -> Result<(), BoxError> {
let mut mock = MockSubgraphService::new();
mock.expect_call()
.times(1)
.withf(|request| {
request.assert_headers(vec![
("aa", "vaa"),
("ab", "vab"),
("ac", "vac"),
("header_from_context", "my_value_from_context"),
])
})
.returning(example_response);
let mut service = HeadersLayer::new(
Arc::new(vec![Operation::Insert(Insert::FromContext(
InsertFromContext {
name: "header_from_context".try_into()?,
from_context: "my_key".to_string(),
},
))]),
Arc::new(RESERVED_HEADERS.iter().collect()),
)
.layer(mock);
service.ready().await?.call(example_request()).await?;
Ok(())
}
// The inserted value is extracted from the request body with the `.operationName` query.
#[tokio::test]
async fn test_insert_from_request_body() -> Result<(), BoxError> {
let mut mock = MockSubgraphService::new();
mock.expect_call()
.times(1)
.withf(|request| {
request.assert_headers(vec![
("aa", "vaa"),
("ab", "vab"),
("ac", "vac"),
("header_from_request", "my_operation_name"),
])
})
.returning(example_response);
let mut service = HeadersLayer::new(
Arc::new(vec![Operation::Insert(Insert::FromBody(InsertFromBody {
name: "header_from_request".try_into()?,
path: JSONQuery::parse(".operationName")?,
default: None,
}))]),
Arc::new(RESERVED_HEADERS.iter().collect()),
)
.layer(mock);
service.ready().await?.call(example_request()).await?;
Ok(())
}
// Removing `aa` by name leaves the other subgraph headers.
#[tokio::test]
async fn test_remove_exact() -> Result<(), BoxError> {
let mut mock = MockSubgraphService::new();
mock.expect_call()
.times(1)
.withf(|request| request.assert_headers(vec![("ac", "vac"), ("ab", "vab")]))
.returning(example_response);
let mut service = HeadersLayer::new(
Arc::new(vec![Operation::Remove(Remove::Named("aa".try_into()?))]),
Arc::new(RESERVED_HEADERS.iter().collect()),
)
.layer(mock);
service.ready().await?.call(example_request()).await?;
Ok(())
}
// The regex `a[ab]` removes `aa` and `ab` and leaves `ac`.
#[tokio::test]
async fn test_remove_matching() -> Result<(), BoxError> {
let mut mock = MockSubgraphService::new();
mock.expect_call()
.times(1)
.withf(|request| request.assert_headers(vec![("ac", "vac")]))
.returning(example_response);
let mut service = HeadersLayer::new(
Arc::new(vec![Operation::Remove(Remove::Matching(Regex::from_str(
"a[ab]",
)?))]),
Arc::new(RESERVED_HEADERS.iter().collect()),
)
.layer(mock);
service.ready().await?.call(example_request()).await?;
Ok(())
}
// The regex `d[ab]` copies the values of `da` and `db` from the supergraph request.
// `assert_headers` compares as a set, so duplicate values are not distinguished.
#[tokio::test]
async fn test_propagate_matching() -> Result<(), BoxError> {
let mut mock = MockSubgraphService::new();
mock.expect_call()
.times(1)
.withf(|request| {
request.assert_headers(vec![
("aa", "vaa"),
("ab", "vab"),
("ac", "vac"),
("da", "vda"),
("db", "vdb"),
("db", "vdb2"),
])
})
.returning(example_response);
let mut service = HeadersLayer::new(
Arc::new(vec![Operation::Propagate(Propagate::Matching {
matching: Regex::from_str("d[ab]")?,
})]),
Arc::new(RESERVED_HEADERS.iter().collect()),
)
.layer(mock);
service.ready().await?.call(example_request()).await?;
Ok(())
}
// `da` is propagated unchanged.
#[tokio::test]
async fn test_propagate_exact() -> Result<(), BoxError> {
let mut mock = MockSubgraphService::new();
mock.expect_call()
.times(1)
.withf(|request| {
request.assert_headers(vec![
("aa", "vaa"),
("ab", "vab"),
("ac", "vac"),
("da", "vda"),
])
})
.returning(example_response);
let mut service = HeadersLayer::new(
Arc::new(vec![Operation::Propagate(Propagate::Named {
named: "da".try_into()?,
rename: None,
default: None,
})]),
Arc::new(RESERVED_HEADERS.iter().collect()),
)
.layer(mock);
service.ready().await?.call(example_request()).await?;
Ok(())
}
// `da` is propagated under the new name `ea`.
#[tokio::test]
async fn test_propagate_exact_rename() -> Result<(), BoxError> {
let mut mock = MockSubgraphService::new();
mock.expect_call()
.times(1)
.withf(|request| {
request.assert_headers(vec![
("aa", "vaa"),
("ab", "vab"),
("ac", "vac"),
("ea", "vda"),
])
})
.returning(example_response);
let mut service = HeadersLayer::new(
Arc::new(vec![Operation::Propagate(Propagate::Named {
named: "da".try_into()?,
rename: Some("ea".try_into()?),
default: None,
})]),
Arc::new(RESERVED_HEADERS.iter().collect()),
)
.layer(mock);
service.ready().await?.call(example_request()).await?;
Ok(())
}
// One source header can feed several targets (`da` -> `ra`, `rb`). A target that has
// already been written (`ra`) is not written again by a later rule (`db` -> `ra`).
#[tokio::test]
async fn test_propagate_multiple() -> Result<(), BoxError> {
let mut mock = MockSubgraphService::new();
mock.expect_call()
.times(1)
.withf(|request| {
request.assert_headers(vec![
("aa", "vaa"),
("ab", "vab"),
("ac", "vac"),
("ra", "vda"),
("rb", "vda"),
])
})
.returning(example_response);
let mut service = HeadersLayer::new(
Arc::new(vec![
Operation::Propagate(Propagate::Named {
named: "da".try_into()?,
rename: Some("ra".try_into()?),
default: None,
}),
Operation::Propagate(Propagate::Named {
named: "da".try_into()?,
rename: Some("rb".try_into()?),
default: None,
}),
Operation::Propagate(Propagate::Named {
named: "db".try_into()?,
rename: Some("ra".try_into()?),
default: None,
}),
]),
Arc::new(RESERVED_HEADERS.iter().collect()),
)
.layer(mock);
service.ready().await?.call(example_request()).await?;
Ok(())
}
// `ea` is absent from the supergraph request, so the configured default is used.
#[tokio::test]
async fn test_propagate_exact_default() -> Result<(), BoxError> {
let mut mock = MockSubgraphService::new();
mock.expect_call()
.times(1)
.withf(|request| {
request.assert_headers(vec![
("aa", "vaa"),
("ab", "vab"),
("ac", "vac"),
("ea", "defaulted"),
])
})
.returning(example_response);
let mut service = HeadersLayer::new(
Arc::new(vec![Operation::Propagate(Propagate::Named {
named: "ea".try_into()?,
rename: None,
default: Some("defaulted".try_into()?),
})]),
Arc::new(RESERVED_HEADERS.iter().collect()),
)
.layer(mock);
service.ready().await?.call(example_request()).await?;
Ok(())
}
// `.*` propagates every non-reserved supergraph header. The subgraph keeps its own
// host / content-length / content-type, and reserved supergraph headers (host,
// content-*, accept*) are not copied. Header order is asserted exactly here.
#[tokio::test]
async fn test_propagate_reserved() -> Result<(), BoxError> {
let service = HeadersService {
inner: MockSubgraphService::new(),
operations: Arc::new(vec![Operation::Propagate(Propagate::Matching {
matching: Regex::from_str(".*")?,
})]),
reserved_headers: Arc::new(RESERVED_HEADERS.iter().collect()),
};
let mut request = SubgraphRequest {
supergraph_request: Arc::new(
http::Request::builder()
.header("da", "vda")
.header("db", "vdb")
.header("db", "vdb")
.header("db", "vdb2")
.header(HOST, "host")
.header(CONTENT_LENGTH, "2")
.header(CONTENT_TYPE, "graphql")
.header(CONTENT_ENCODING, "identity")
.header(ACCEPT, "application/json")
.header(ACCEPT_ENCODING, "gzip")
.body(
Request::builder()
.query("query")
.operation_name("my_operation_name")
.build(),
)
.expect("expecting valid request"),
),
subgraph_request: http::Request::builder()
.header("aa", "vaa")
.header("ab", "vab")
.header("ac", "vac")
.header(HOST, "rhost")
.header(CONTENT_LENGTH, "22")
.header(CONTENT_TYPE, "graphql")
.body(Request::builder().query("query").build())
.expect("expecting valid request"),
operation_kind: OperationKind::Query,
context: Context::new(),
subgraph_name: String::from("test").into(),
subscription_stream: None,
connection_closed_signal: None,
query_hash: Default::default(),
authorization: Default::default(),
executable_document: None,
id: SubgraphRequestId(String::new()),
};
service.modify_request(&mut request);
let headers = request
.subgraph_request
.headers()
.iter()
.map(|(name, value)| (name.as_str(), value.to_str().unwrap()))
.collect::<Vec<_>>();
assert_eq!(
headers,
vec![
("aa", "vaa"),
("ab", "vab"),
("ac", "vac"),
("host", "rhost"),
("content-length", "22"),
("content-type", "graphql"),
("da", "vda"),
("db", "vdb"),
("db", "vdb"),
("db", "vdb2"),
]
);
Ok(())
}
// A header already propagated by a `named` rule is not propagated a second time by a
// later `matching` rule.
#[tokio::test]
async fn test_propagate_multiple_matching_rules() -> Result<(), BoxError> {
let service = HeadersService {
inner: MockSubgraphService::new(),
operations: Arc::new(vec![
Operation::Propagate(Propagate::Named {
named: HeaderName::from_static("dc"),
rename: None,
default: None,
}),
Operation::Propagate(Propagate::Matching {
matching: Regex::from_str("dc")?,
}),
]),
reserved_headers: Arc::new(RESERVED_HEADERS.iter().collect()),
};
let mut request = SubgraphRequest {
supergraph_request: Arc::new(
http::Request::builder()
.header("da", "vda")
.header("db", "vdb")
.header("dc", "vdb2")
.body(
Request::builder()
.query("query")
.operation_name("my_operation_name")
.build(),
)
.expect("expecting valid request"),
),
subgraph_request: http::Request::builder()
.header("aa", "vaa")
.header("ab", "vab")
.header("ac", "vac")
.body(Request::builder().query("query").build())
.expect("expecting valid request"),
operation_kind: OperationKind::Query,
context: Context::new(),
subgraph_name: String::from("test").into(),
subscription_stream: None,
connection_closed_signal: None,
query_hash: Default::default(),
authorization: Default::default(),
executable_document: None,
id: SubgraphRequestId(String::new()),
};
service.modify_request(&mut request);
let headers = request
.subgraph_request
.headers()
.iter()
.map(|(name, value)| (name.as_str(), value.to_str().unwrap()))
.collect::<Vec<_>>();
assert_eq!(
headers,
vec![("aa", "vaa"), ("ab", "vab"), ("ac", "vac"), ("dc", "vdb2"),]
);
Ok(())
}
// Mock subgraph response carrying the request's subgraph name.
fn example_response(req: SubgraphRequest) -> Result<SubgraphResponse, BoxError> {
Ok(SubgraphResponse::new_from_response(
http::Response::default(),
Context::new(),
req.subgraph_name.unwrap_or_default(),
SubgraphRequestId(String::new()),
))
}
// Shared fixture, with three parts:
// - a supergraph request with `da`/`db` and some reserved headers;
// - a subgraph request with `aa`/`ab`/`ac` plus host/content-length/content-type;
// - `my_key` stored in the context.
fn example_request() -> SubgraphRequest {
let ctx = Context::new();
ctx.insert("my_key", "my_value_from_context".to_string())
.unwrap();
SubgraphRequest {
supergraph_request: Arc::new(
http::Request::builder()
.header("da", "vda")
.header("db", "vdb")
.header("db", "vdb")
.header("db", "vdb2")
.header(HOST, "host")
.header(CONTENT_LENGTH, "2")
.header(CONTENT_TYPE, "graphql")
.body(
Request::builder()
.query("query")
.operation_name("my_operation_name")
.build(),
)
.expect("expecting valid request"),
),
subgraph_request: http::Request::builder()
.header("aa", "vaa")
.header("ab", "vab")
.header("ac", "vac")
.header(HOST, "rhost")
.header(CONTENT_LENGTH, "22")
.header(CONTENT_TYPE, "graphql")
.body(Request::builder().query("query").build())
.expect("expecting valid request"),
operation_kind: OperationKind::Query,
context: ctx,
subgraph_name: String::from("test").into(),
subscription_stream: None,
connection_closed_signal: None,
query_hash: Default::default(),
authorization: Default::default(),
executable_document: None,
id: SubgraphRequestId(String::new()),
}
}
impl SubgraphRequest {
// Asserts that the subgraph request's headers equal `headers` plus the
// host/content-length/content-type set by `example_request`. The comparison is
// set-based, so order and duplicates are ignored. Always returns `true` so it can be
// used as a mock `withf` predicate; a mismatch panics instead.
fn assert_headers(&self, headers: Vec<(&'static str, &'static str)>) -> bool {
let mut headers = headers.clone();
headers.push((HOST.as_str(), "rhost"));
headers.push((CONTENT_LENGTH.as_str(), "22"));
headers.push((CONTENT_TYPE.as_str(), "graphql"));
let actual_headers = self
.subgraph_request
.headers()
.iter()
.map(|(name, value)| (name.as_str(), value.to_str().unwrap()))
.collect::<HashSet<_>>();
assert_eq!(actual_headers, headers.into_iter().collect::<HashSet<_>>());
true
}
}
// End-to-end helper that builds the plugin from `config` via `PluginTestHarness`.
// It sends a subgraph request (subgraph "test") whose supergraph request carries the
// `input` headers. It then asserts that each `output` header is present on the subgraph
// request with the given (first) value.
async fn assert_headers(
config: &'static str,
input: Vec<(&'static str, &'static str)>,
output: Vec<(&'static str, &'static str)>,
) {
let mut req = http::Request::builder();
for (name, value) in input.iter() {
req = req.header(*name, *value);
}
let test_harness = PluginTestHarness::<Headers>::builder()
.config(config)
.build()
.await;
let response = test_harness
.call_subgraph(
subgraph::Request::fake_builder()
.subgraph_name("test")
.supergraph_request(Arc::new(
req.body(graphql::Request::default())
.expect("valid request"),
))
.build(),
move |r| {
let output = output.clone();
async move {
let output = output.clone();
let headers = r.subgraph_request.headers();
for (name, value) in output.iter() {
if let Some(header) = headers.get(*name) {
assert_eq!(header.to_str().unwrap(), *value);
} else {
panic!("missing header {}", name);
}
}
Ok(subgraph::Response::fake_builder().build())
}
},
)
.await;
assert!(response.is_ok());
}
// Expectations depend on `fixtures/propagate_passthrough.router.yaml`, which is not
// visible here.
#[tokio::test]
async fn test_propagate_passthrough() {
assert_headers(
include_str!("fixtures/propagate_passthrough.router.yaml"),
vec![("a", "av"), ("c", "cv")],
vec![("a", "av"), ("b", "av"), ("c", "cv")],
)
.await;
assert_headers(
include_str!("fixtures/propagate_passthrough.router.yaml"),
vec![("b", "bv"), ("c", "cv")],
vec![("b", "bv"), ("c", "cv")],
)
.await;
}
// Expectations depend on `fixtures/propagate_passthrough_defaulted.router.yaml`; the
// last case shows `b` falling back to "defaulted" when neither `a` nor `b` is sent.
#[tokio::test]
async fn test_propagate_passthrough_defaulted() {
assert_headers(
include_str!("fixtures/propagate_passthrough_defaulted.router.yaml"),
vec![("a", "av")],
vec![("b", "av")],
)
.await;
assert_headers(
include_str!("fixtures/propagate_passthrough_defaulted.router.yaml"),
vec![("b", "bv")],
vec![("b", "bv")],
)
.await;
assert_headers(
include_str!("fixtures/propagate_passthrough_defaulted.router.yaml"),
vec![("c", "cv")],
vec![("b", "defaulted")],
)
.await;
}
}