use std::collections::HashMap;
use std::path::Path;
use std::path::PathBuf;
use std::sync::Arc;
use futures::StreamExt;
use futures::stream::once;
use http::HeaderMap;
use http::HeaderValue;
use http_body_util::BodyExt;
use tokio::fs;
use tower::BoxError;
use tower::ServiceBuilder;
use tower::ServiceExt as TowerServiceExt;
use super::recording::Recording;
use super::recording::RequestDetails;
use super::recording::ResponseDetails;
use super::recording::Subgraph;
use crate::layers::ServiceBuilderExt;
use crate::plugin::Plugin;
use crate::plugin::PluginInit;
use crate::services::execution;
use crate::services::router;
use crate::services::subgraph;
use crate::services::supergraph;
/// Request header that opts a client request into being recorded.
const RECORD_HEADER: &str = "x-apollo-router-record";
/// User-facing configuration for the recording plugin.
#[derive(Debug, Clone, serde::Deserialize, schemars::JsonSchema)]
#[serde(deny_unknown_fields)]
struct RecordConfig {
    // Master switch: when false every service hook passes traffic through
    // untouched.
    enabled: bool,
    // Directory recordings (and the README) are written to. Falls back to
    // the process's current working directory when unset.
    storage_path: Option<PathBuf>,
}
/// Default recording location: the router process's current working
/// directory.
fn default_storage_path() -> PathBuf {
    let cwd = std::env::current_dir();
    // There is no sensible fallback at startup if the CWD is unavailable
    // (e.g. it was deleted), so failing loudly is deliberate.
    cwd.expect("failed to get current directory")
}
/// State for the experimental `record` plugin: captures client, supergraph,
/// and subgraph traffic into JSON recording files on disk.
#[derive(Debug)]
struct Record {
    // Snapshot of `RecordConfig::enabled` taken at plugin construction.
    enabled: bool,
    // Supergraph schema SDL, embedded into every recording.
    supergraph_sdl: Arc<String>,
    // Resolved directory recordings are written to.
    storage_path: Arc<Path>,
}
// Registers the plugin under the `experimental` group with the name `record`.
register_plugin!("experimental", "record", Record);
#[async_trait::async_trait]
impl Plugin for Record {
    type Config = RecordConfig;

    /// Builds the plugin and, when recording is enabled, writes a README into
    /// the storage directory describing the recording files.
    async fn new(init: PluginInit<Self::Config>) -> Result<Self, BoxError> {
        let storage_path = init
            .config
            .storage_path
            .unwrap_or_else(default_storage_path);

        let plugin = Self {
            enabled: init.config.enabled,
            supergraph_sdl: init.supergraph_sdl.clone(),
            storage_path: storage_path.clone().into(),
        };

        if init.config.enabled {
            write_file(
                storage_path.into(),
                &PathBuf::from("README.md"),
                include_str!("recording-readme.md").as_bytes(),
            )
            .await?;
        }

        Ok(plugin)
    }

    /// Wraps the router service so that, after the client response stream has
    /// been fully sent, the accumulated `Recording` (if one was started) is
    /// serialized and written to disk.
    fn router_service(&self, service: router::BoxService) -> router::BoxService {
        if !self.enabled {
            return service;
        }

        let dir = self.storage_path.clone();

        ServiceBuilder::new()
            .map_future(move |future| {
                let dir = dir.clone();
                async move {
                    let res: router::Response = future.await?;
                    let (parts, stream) = res.response.into_parts();
                    let headers = parts.headers.clone();
                    let context = res.context.clone();

                    // One-shot stream chained after the body: it only runs
                    // once the last data chunk has been yielded, so the
                    // recording is complete before it is written out.
                    let after_complete = once(async move {
                        let recording = context
                            .extensions()
                            .with_lock(|lock| lock.remove::<Recording>());

                        if let Some(mut recording) = recording {
                            let (headers, header_errors) = parse_headers(&headers);
                            recording.client_response.headers = headers;
                            recording.client_response.header_errors = header_errors;

                            let filename = recording.filename();
                            let contents = serde_json::to_value(recording)?;

                            // NOTE(review): the spawned task is awaited
                            // immediately, so the spawn does not detach the
                            // write — confirm whether that is intentional.
                            tokio::spawn(async move {
                                tracing::info!("Writing recording to {:?}", filename);
                                write_file(
                                    dir,
                                    &filename,
                                    serde_json::to_string_pretty(&contents)?.as_bytes(),
                                )
                                .await?;
                                Ok::<(), BoxError>(())
                            })
                            .await??;
                        }

                        // Yield no extra chunk to the client.
                        Ok::<Option<_>, BoxError>(None)
                    })
                    // NOTE(review): a failed write surfaces here as a panic
                    // via `unwrap` — confirm this is the desired behavior.
                    .filter_map(|a| async move { a.unwrap() });

                    let stream = stream.into_data_stream().chain(after_complete);

                    router::Response::http_response_builder()
                        .context(res.context)
                        .response(http::Response::from_parts(
                            parts,
                            router::body::from_result_stream(stream),
                        ))
                        .build()
                }
            })
            .service(service)
            .boxed()
    }

    /// Starts a `Recording` when the client request carries the
    /// `x-apollo-router-record` header, captures the client request details,
    /// and accumulates every response chunk as it streams back.
    fn supergraph_service(&self, service: supergraph::BoxService) -> supergraph::BoxService {
        if !self.enabled {
            return service;
        }

        let supergraph_sdl = self.supergraph_sdl.clone();

        ServiceBuilder::new()
            .map_request(move |req: supergraph::Request| {
                // Pure introspection operations are never recorded.
                if is_introspection(&req) {
                    return req;
                }

                let recording_enabled =
                    if req.supergraph_request.headers().contains_key(RECORD_HEADER) {
                        req.context.extensions().with_lock(|lock| {
                            lock.insert(Recording {
                                supergraph_sdl: supergraph_sdl.clone().to_string(),
                                client_request: Default::default(),
                                client_response: Default::default(),
                                formatted_query_plan: Default::default(),
                                subgraph_fetches: Default::default(),
                            })
                        });
                        true
                    } else {
                        false
                    };

                if recording_enabled {
                    let query = req.supergraph_request.body().query.clone();
                    let operation_name = req.supergraph_request.body().operation_name.clone();
                    let variables = req.supergraph_request.body().variables.clone();
                    let (headers, header_errors) = parse_headers(req.supergraph_request.headers());
                    let method = req.supergraph_request.method().to_string();
                    let uri = req.supergraph_request.uri().to_string();

                    req.context.extensions().with_lock(|lock| {
                        if let Some(recording) = lock.get_mut::<Recording>() {
                            recording.client_request = RequestDetails {
                                query,
                                operation_name,
                                variables,
                                headers,
                                header_errors,
                                method,
                                uri,
                            };
                        }
                    });
                }

                req
            })
            .map_response(|res: supergraph::Response| {
                let context = res.context.clone();
                // Clone each chunk into the recording as it flows to the
                // client; the stream itself passes through unchanged.
                res.map_stream(move |chunk| {
                    context.extensions().with_lock(|lock| {
                        if let Some(recording) = lock.get_mut::<Recording>() {
                            recording.client_response.chunks.push(chunk.clone());
                        }
                    });
                    chunk
                })
            })
            .service(service)
            .boxed()
    }

    /// Captures the formatted query plan into the recording.
    fn execution_service(&self, service: execution::BoxService) -> execution::BoxService {
        // Consistent with the other hooks: skip wrapping when the plugin is
        // disabled. (Previously this hook always wrapped the service even
        // though no `Recording` can exist in the context when disabled.)
        if !self.enabled {
            return service;
        }

        ServiceBuilder::new()
            .map_request(|req: execution::Request| {
                req.context.extensions().with_lock(|lock| {
                    if let Some(recording) = lock.get_mut::<Recording>() {
                        recording.formatted_query_plan =
                            req.query_plan.formatted_query_plan.clone();
                    }
                });
                req
            })
            .service(service)
            .boxed()
    }

    /// Records each subgraph fetch (request and response) keyed by the
    /// fetch's operation name.
    fn subgraph_service(
        &self,
        subgraph_name: &str,
        service: subgraph::BoxService,
    ) -> subgraph::BoxService {
        if !self.enabled {
            return service;
        }

        let subgraph_name = String::from(subgraph_name);

        ServiceBuilder::new()
            .map_future_with_request_data(
                // Snapshot the request details up front: the request itself
                // is consumed by the inner service.
                |req: &subgraph::Request| {
                    let (headers, header_errors) = parse_headers(req.subgraph_request.headers());
                    RequestDetails {
                        query: req.subgraph_request.body().query.clone(),
                        operation_name: req.subgraph_request.body().operation_name.clone(),
                        variables: req.subgraph_request.body().variables.clone(),
                        headers,
                        header_errors,
                        method: req.subgraph_request.method().to_string(),
                        uri: req.subgraph_request.uri().to_string(),
                    }
                },
                move |req: RequestDetails, future| {
                    let subgraph_name = subgraph_name.clone();
                    async move {
                        let res: subgraph::ServiceResult = future.await;
                        let res = res?;

                        // Fetches are keyed by operation name in the recording.
                        let operation_name = req
                            .operation_name
                            .clone()
                            .unwrap_or_else(|| "UnnamedOperation".to_string());

                        // Borrow the headers directly; the previous clone of
                        // the whole HeaderMap was redundant.
                        let (headers, header_errors) = parse_headers(res.response.headers());

                        let subgraph = Subgraph {
                            subgraph_name,
                            response: ResponseDetails {
                                headers,
                                header_errors,
                                chunks: vec![res.response.body().clone()],
                            },
                            request: req,
                        };

                        res.context.extensions().with_lock(|lock| {
                            if let Some(recording) = lock.get_mut::<Recording>() {
                                if recording.subgraph_fetches.is_none() {
                                    recording.subgraph_fetches = Some(Default::default());
                                }
                                if let Some(fetches) = &mut recording.subgraph_fetches {
                                    fetches.insert(operation_name, subgraph);
                                }
                            }
                        });

                        Ok(res)
                    }
                },
            )
            .service(service)
            .boxed()
    }
}
/// Writes `contents` to `dir.join(path)`, creating any missing parent
/// directories first.
///
/// Takes `&Path` rather than `&PathBuf` (clippy `ptr_arg`); existing call
/// sites passing `&PathBuf` coerce transparently.
///
/// # Errors
/// Fails when the joined path has no parent, or when directory creation or
/// the write itself fails.
async fn write_file(dir: Arc<Path>, path: &Path, contents: &[u8]) -> Result<(), BoxError> {
    let path = dir.join(path);
    let dir = path.parent().ok_or("invalid record directory")?;
    fs::create_dir_all(dir).await?;
    fs::write(path, contents).await?;
    Ok(())
}
/// Returns true when the request's operation consists solely of top-level
/// introspection meta-fields (`__typename`, `__schema`, `__type`).
///
/// Requests with no parsed executable document, or whose named operation
/// cannot be found, are treated as non-introspection.
fn is_introspection(request: &supergraph::Request) -> bool {
    let operation_name = request.supergraph_request.body().operation_name.as_deref();
    request.context.executable_document().is_some_and(|doc| {
        let operation = doc.operations.get(operation_name).ok();
        operation.is_some_and(|op| {
            op.root_fields(&doc)
                .all(|field| matches!(field.name.as_str(), "__typename" | "__schema" | "__type"))
        })
    })
}
/// Splits a header map into decodable values and decode failures.
///
/// Both returned maps are keyed by header name and hold the list of values
/// seen for that name: valid UTF-8 values go into the first map, and values
/// that are not valid UTF-8 are reported in the second as the decode error's
/// message.
fn parse_headers(
    input: &HeaderMap<HeaderValue>,
) -> (HashMap<String, Vec<String>>, HashMap<String, Vec<String>>) {
    let mut headers: HashMap<String, Vec<String>> = HashMap::new();
    let mut header_errors: HashMap<String, Vec<String>> = HashMap::new();

    for (k, v) in input {
        let k = k.as_str().to_owned();
        // Validate in place (no Vec copy of the raw bytes); only values that
        // pass validation are copied into an owned String. `Utf8Error`
        // renders the same message `FromUtf8Error` did.
        match std::str::from_utf8(v.as_bytes()) {
            Ok(v) => headers.entry(k).or_default().push(v.to_owned()),
            Err(e) => header_errors.entry(k).or_default().push(e.to_string()),
        }
    }

    (headers, header_errors)
}