apollo-router 2.14.0-rc.2

A configurable, high-performance routing runtime for Apollo Federation 🚀
Documentation
use std::collections::HashMap;
use std::path::Path;
use std::path::PathBuf;
use std::sync::Arc;

use futures::StreamExt;
use futures::stream::once;
use http::HeaderMap;
use http::HeaderValue;
use http_body_util::BodyExt;
use tokio::fs;
use tower::BoxError;
use tower::ServiceBuilder;
use tower::ServiceExt as TowerServiceExt;

use super::recording::Recording;
use super::recording::RequestDetails;
use super::recording::ResponseDetails;
use super::recording::Subgraph;
use crate::layers::ServiceBuilderExt;
use crate::plugin::Plugin;
use crate::plugin::PluginInit;
use crate::services::execution;
use crate::services::router;
use crate::services::subgraph;
use crate::services::supergraph;

/// Name of the client request header that opts a request into recording. The
/// supergraph hook only checks whether this header is present; its value is not
/// inspected.
const RECORD_HEADER: &str = "x-apollo-router-record";

// Even when enabled, only requests carrying the `x-apollo-router-record` header
// are recorded. Each recording is written as a file under `storage_path`, next to
// a `README.md` that the plugin writes when it is constructed with `enabled: true`.
/// Request recording configuration.
#[derive(Debug, Clone, serde::Deserialize, schemars::JsonSchema)]
#[serde(deny_unknown_fields)]
struct RecordConfig {
    // NOTE(review): there is no `#[serde(default)]` on this field, so it must be
    // present whenever the plugin is configured; "disabled by default" presumably
    // means "not configured at all" — confirm.
    /// The recording plugin is disabled by default.
    enabled: bool,
    /// The path to the directory where recordings will be stored. Defaults to
    /// the current working directory.
    storage_path: Option<PathBuf>,
}

/// Returns the directory used for recordings when `storage_path` is not configured:
/// the process's current working directory.
///
/// If the working directory cannot be determined (for example because it has been
/// removed), this falls back to the relative path `.` instead of panicking. Any real
/// problem with the directory then surfaces as an I/O error from the plugin's file
/// writes rather than as a panic during plugin construction. This function is
/// evaluated even when recording is disabled.
fn default_storage_path() -> PathBuf {
    std::env::current_dir().unwrap_or_else(|_| PathBuf::from("."))
}

/// Runtime state of the request recording plugin, built from [`RecordConfig`].
#[derive(Debug)]
struct Record {
    /// Copied from `RecordConfig::enabled`. When `false`, the router, supergraph and
    /// subgraph hooks return the wrapped service unchanged.
    enabled: bool,
    /// Supergraph SDL, copied into every recording that is started.
    supergraph_sdl: Arc<String>,
    /// Directory that recordings (and the plugin's `README.md`) are written into.
    storage_path: Arc<Path>,
}

// Registers `Record` with the plugin registry under the "experimental" group and the
// name "record".
register_plugin!("experimental", "record", Record);

#[async_trait::async_trait]
impl Plugin for Record {
    type Config = RecordConfig;

    async fn new(init: PluginInit<Self::Config>) -> Result<Self, BoxError> {
        let storage_path = init
            .config
            .storage_path
            .unwrap_or_else(default_storage_path);

        let plugin = Self {
            enabled: init.config.enabled,
            supergraph_sdl: init.supergraph_sdl.clone(),
            storage_path: storage_path.clone().into(),
        };

        if init.config.enabled {
            write_file(
                storage_path.into(),
                &PathBuf::from("README.md"),
                include_str!("recording-readme.md").as_bytes(),
            )
            .await?;
        }

        Ok(plugin)
    }

    fn router_service(&self, service: router::BoxService) -> router::BoxService {
        if !self.enabled {
            return service;
        }

        let dir = self.storage_path.clone();

        ServiceBuilder::new()
            .map_future(move |future| {
                let dir = dir.clone();

                async move {
                    let res: router::Response = future.await?;
                    let (parts, stream) = res.response.into_parts();

                    let headers = parts.headers.clone();
                    let context = res.context.clone();

                    let after_complete = once(async move {
                        let recording = context
                            .extensions()
                            .with_lock(|lock| lock.remove::<Recording>());

                        if let Some(mut recording) = recording {
                            let (headers, header_errors) = parse_headers(&headers);
                            recording.client_response.headers = headers;
                            recording.client_response.header_errors = header_errors;

                            let filename = recording.filename();
                            let contents = serde_json::to_value(recording)?;

                            tokio::spawn(async move {
                                tracing::info!("Writing recording to {:?}", filename);

                                write_file(
                                    dir,
                                    &filename,
                                    serde_json::to_string_pretty(&contents)?.as_bytes(),
                                )
                                .await?;

                                Ok::<(), BoxError>(())
                            })
                            .await??;
                        }
                        Ok::<Option<_>, BoxError>(None)
                    })
                    .filter_map(|a| async move { a.unwrap() });

                    let stream = stream.into_data_stream().chain(after_complete);

                    router::Response::http_response_builder()
                        .context(res.context)
                        .response(http::Response::from_parts(
                            parts,
                            router::body::from_result_stream(stream),
                        ))
                        .build()
                }
            })
            .service(service)
            .boxed()
    }

    fn supergraph_service(&self, service: supergraph::BoxService) -> supergraph::BoxService {
        if !self.enabled {
            return service;
        }

        let supergraph_sdl = self.supergraph_sdl.clone();

        ServiceBuilder::new()
            .map_request(move |req: supergraph::Request| {
                if is_introspection(&req) {
                    return req;
                }

                let recording_enabled =
                    if req.supergraph_request.headers().contains_key(RECORD_HEADER) {
                        req.context.extensions().with_lock(|lock| {
                            lock.insert(Recording {
                                supergraph_sdl: supergraph_sdl.clone().to_string(),
                                client_request: Default::default(),
                                client_response: Default::default(),
                                formatted_query_plan: Default::default(),
                                subgraph_fetches: Default::default(),
                            })
                        });
                        true
                    } else {
                        false
                    };

                if recording_enabled {
                    let query = req.supergraph_request.body().query.clone();
                    let operation_name = req.supergraph_request.body().operation_name.clone();
                    let variables = req.supergraph_request.body().variables.clone();
                    let (headers, header_errors) = parse_headers(req.supergraph_request.headers());
                    let method = req.supergraph_request.method().to_string();
                    let uri = req.supergraph_request.uri().to_string();

                    req.context.extensions().with_lock(|lock| {
                        if let Some(recording) = lock.get_mut::<Recording>() {
                            recording.client_request = RequestDetails {
                                query,
                                operation_name,
                                variables,
                                headers,
                                header_errors,
                                method,
                                uri,
                            };
                        }
                    });
                }
                req
            })
            .map_response(|res: supergraph::Response| {
                let context = res.context.clone();
                res.map_stream(move |chunk| {
                    context.extensions().with_lock(|lock| {
                        if let Some(recording) = lock.get_mut::<Recording>() {
                            recording.client_response.chunks.push(chunk.clone());
                        }
                    });

                    chunk
                })
            })
            .service(service)
            .boxed()
    }

    fn execution_service(&self, service: execution::BoxService) -> execution::BoxService {
        ServiceBuilder::new()
            .map_request(|req: execution::Request| {
                req.context.extensions().with_lock(|lock| {
                    if let Some(recording) = lock.get_mut::<Recording>() {
                        recording.formatted_query_plan =
                            req.query_plan.formatted_query_plan.clone();
                    }
                });
                req
            })
            .service(service)
            .boxed()
    }

    fn subgraph_service(
        &self,
        subgraph_name: &str,
        service: subgraph::BoxService,
    ) -> subgraph::BoxService {
        if !self.enabled {
            return service;
        }

        let subgraph_name = String::from(subgraph_name);

        ServiceBuilder::new()
            .map_future_with_request_data(
                |req: &subgraph::Request| {
                    let (headers, header_errors) = parse_headers(req.subgraph_request.headers());

                    RequestDetails {
                        query: req.subgraph_request.body().query.clone(),
                        operation_name: req.subgraph_request.body().operation_name.clone(),
                        variables: req.subgraph_request.body().variables.clone(),
                        headers,
                        header_errors,
                        method: req.subgraph_request.method().to_string(),
                        uri: req.subgraph_request.uri().to_string(),
                    }
                },
                move |req: RequestDetails, future| {
                    let subgraph_name = subgraph_name.clone();
                    async move {
                        let res: subgraph::ServiceResult = future.await;
                        let res = res?;

                        let operation_name = req
                            .operation_name
                            .clone()
                            .unwrap_or_else(|| "UnnamedOperation".to_string());

                        let (headers, header_errors) =
                            parse_headers(&res.response.headers().clone());

                        let subgraph = Subgraph {
                            subgraph_name,
                            response: ResponseDetails {
                                headers,
                                header_errors,
                                chunks: vec![res.response.body().clone()],
                            },
                            request: req,
                        };

                        res.context.extensions().with_lock(|lock| {
                            if let Some(recording) = lock.get_mut::<Recording>() {
                                if recording.subgraph_fetches.is_none() {
                                    recording.subgraph_fetches = Some(Default::default());
                                }

                                if let Some(fetches) = &mut recording.subgraph_fetches {
                                    fetches.insert(operation_name, subgraph);
                                }
                            }
                        });
                        Ok(res)
                    }
                },
            )
            .service(service)
            .boxed()
    }
}

/// Writes `contents` to `path` resolved against `dir`, creating any missing parent
/// directories first. An existing file at that location is overwritten.
///
/// Note that `Path::join` replaces `dir` entirely when `path` is absolute.
///
/// # Errors
///
/// Returns an error if the joined path has no parent directory, if the parent
/// directories cannot be created, or if the file cannot be written.
async fn write_file(dir: Arc<Path>, path: &Path, contents: &[u8]) -> Result<(), BoxError> {
    let target = dir.join(path);
    let parent = target.parent().ok_or("invalid record directory")?;
    fs::create_dir_all(parent).await?;
    fs::write(&target, contents).await?;
    Ok(())
}

/// Reports whether the request's selected operation (chosen by the request's
/// `operation_name`) selects only the meta fields `__typename`, `__schema` and
/// `__type` at its root.
///
/// Returns `false` when the context holds no parsed executable document, or when
/// the operation cannot be resolved from the document.
fn is_introspection(request: &supergraph::Request) -> bool {
    const META_FIELDS: [&str; 3] = ["__typename", "__schema", "__type"];

    let operation_name = request.supergraph_request.body().operation_name.as_deref();

    match request.context.executable_document() {
        None => false,
        Some(doc) => match doc.operations.get(operation_name) {
            Err(_) => false,
            Ok(operation) => operation
                .root_fields(&doc)
                .all(|field| META_FIELDS.contains(&field.name.as_str())),
        },
    }
}

/// Splits `input` into two maps keyed by header name.
///
/// The first map holds every value that decodes as UTF-8. The second holds, for
/// each value that does not, the decoding error message. Headers that appear
/// several times keep all their values, in the order the header map yields them.
fn parse_headers(
    input: &HeaderMap<HeaderValue>,
) -> (HashMap<String, Vec<String>>, HashMap<String, Vec<String>>) {
    let mut decoded = HashMap::<String, Vec<String>>::new();
    let mut failures = HashMap::<String, Vec<String>>::new();

    for (name, value) in input.iter() {
        // Validate in place and only allocate the text that is kept.
        let (bucket, item) = match std::str::from_utf8(value.as_bytes()) {
            Ok(text) => (&mut decoded, text.to_owned()),
            Err(err) => (&mut failures, err.to_string()),
        };
        bucket.entry(name.as_str().to_owned()).or_default().push(item);
    }

    (decoded, failures)
}