apollo_router/plugin/test/mock/
subgraph.rs

1//! Mock subgraph implementation
2
3#![allow(missing_docs)] // FIXME
4
5use std::collections::HashMap;
6use std::sync::Arc;
7use std::task::Poll;
8
9use apollo_compiler::ast::Definition;
10use apollo_compiler::ast::Document;
11use futures::future;
12use http::HeaderMap;
13use http::HeaderName;
14use http::HeaderValue;
15use http::StatusCode;
16use tower::BoxError;
17use tower::Service;
18
19use crate::graphql;
20use crate::graphql::Request;
21use crate::graphql::Response;
22use crate::json_ext::Object;
23use crate::plugins::subscription::notification::Handle;
24use crate::services::SubgraphRequest;
25use crate::services::SubgraphResponse;
26
27type MockResponses = HashMap<Request, Response>;
28
29#[derive(Default, Clone)]
30pub struct MockSubgraph {
31    // using an arc to improve efficiency when service is cloned
32    mocks: Arc<MockResponses>,
33    extensions: Option<Object>,
34    subscription_stream: Option<Handle<String, graphql::Response>>,
35    map_request_fn:
36        Option<Arc<dyn (Fn(SubgraphRequest) -> SubgraphRequest) + Send + Sync + 'static>>,
37    headers: HeaderMap,
38}
39
40impl MockSubgraph {
41    pub fn new(mocks: MockResponses) -> Self {
42        Self {
43            mocks: Arc::new(
44                mocks
45                    .into_iter()
46                    .map(|(mut req, res)| {
47                        normalize(&mut req);
48                        (req, res)
49                    })
50                    .collect(),
51            ),
52            extensions: None,
53            subscription_stream: None,
54            map_request_fn: None,
55            headers: HeaderMap::new(),
56        }
57    }
58
59    pub fn builder() -> MockSubgraphBuilder {
60        MockSubgraphBuilder::default()
61    }
62
63    pub fn with_extensions(mut self, extensions: Object) -> Self {
64        self.extensions = Some(extensions);
65        self
66    }
67
68    pub fn with_subscription_stream(
69        mut self,
70        subscription_stream: Handle<String, graphql::Response>,
71    ) -> Self {
72        self.subscription_stream = Some(subscription_stream);
73        self
74    }
75
76    #[cfg(test)]
77    pub(crate) fn with_map_request<F>(mut self, map_request_fn: F) -> Self
78    where
79        F: (Fn(SubgraphRequest) -> SubgraphRequest) + Send + Sync + 'static,
80    {
81        self.map_request_fn = Some(Arc::new(map_request_fn));
82        self
83    }
84}
85
86/// Builder for `MockSubgraph`
87#[derive(Default, Clone)]
88pub struct MockSubgraphBuilder {
89    mocks: MockResponses,
90    extensions: Option<Object>,
91    subscription_stream: Option<Handle<String, graphql::Response>>,
92    headers: HeaderMap,
93}
94impl MockSubgraphBuilder {
95    pub fn with_extensions(mut self, extensions: Object) -> Self {
96        self.extensions = Some(extensions);
97        self
98    }
99
100    /// adds a mocked response for a request
101    ///
102    /// the arguments must deserialize to `crate::graphql::Request` and `crate::graphql::Response`
103    pub fn with_json(mut self, request: serde_json::Value, response: serde_json::Value) -> Self {
104        let mut request = serde_json::from_value(request).unwrap();
105        normalize(&mut request);
106        self.mocks
107            .insert(request, serde_json::from_value(response).unwrap());
108        self
109    }
110
111    pub fn with_subscription_stream(
112        mut self,
113        subscription_stream: Handle<String, graphql::Response>,
114    ) -> Self {
115        self.subscription_stream = Some(subscription_stream);
116        self
117    }
118
119    pub fn with_header(mut self, name: HeaderName, value: HeaderValue) -> Self {
120        self.headers.insert(name, value);
121        self
122    }
123
124    pub fn build(self) -> MockSubgraph {
125        MockSubgraph {
126            mocks: Arc::new(self.mocks),
127            extensions: self.extensions,
128            subscription_stream: self.subscription_stream,
129            map_request_fn: None,
130            headers: self.headers,
131        }
132    }
133}
134
135// Normalize queries so that spaces and operation names
136// don't have an impact on the cache
137fn normalize(request: &mut Request) {
138    if let Some(q) = &request.query {
139        let mut doc = Document::parse(q.clone(), "request").unwrap();
140
141        if let Some(Definition::OperationDefinition(op)) = doc.definitions.first_mut() {
142            let o = op.make_mut();
143            o.name.take();
144        };
145
146        request.query = Some(doc.serialize().no_indent().to_string());
147        request.operation_name = None;
148    }
149}
150
151impl Service<SubgraphRequest> for MockSubgraph {
152    type Response = SubgraphResponse;
153
154    type Error = BoxError;
155
156    type Future = future::Ready<Result<Self::Response, Self::Error>>;
157
158    fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
159        Poll::Ready(Ok(()))
160    }
161
162    fn call(&mut self, mut req: SubgraphRequest) -> Self::Future {
163        if let Some(map_request_fn) = &self.map_request_fn {
164            req = map_request_fn.clone()(req);
165        }
166        let body = req.subgraph_request.body_mut();
167
168        let subscription_stream = self.subscription_stream.clone();
169        if let Some(sub_stream) = &mut req.subscription_stream {
170            sub_stream
171                .try_send(Box::pin(
172                    subscription_stream
173                        .expect("must have a subscription stream set")
174                        .into_stream(),
175                ))
176                .unwrap();
177        }
178
179        // Redact the callbackUrl and subscriptionId because it generates a subscription uuid
180        if let Some(serde_json_bytes::Value::Object(subscription_ext)) =
181            body.extensions.get_mut("subscription")
182        {
183            if let Some(callback_url) = subscription_ext.get_mut("callbackUrl") {
184                let mut cb_url = url::Url::parse(
185                    callback_url
186                        .as_str()
187                        .expect("callbackUrl extension must be a string"),
188                )
189                .expect("callbackUrl must be a valid URL");
190                cb_url.path_segments_mut().unwrap().pop();
191                cb_url.path_segments_mut().unwrap().push("subscription_id");
192
193                *callback_url = serde_json_bytes::Value::String(cb_url.to_string().into());
194            }
195            if let Some(subscription_id) = subscription_ext.get_mut("subscriptionId") {
196                *subscription_id =
197                    serde_json_bytes::Value::String("subscriptionId".to_string().into());
198            }
199        }
200
201        normalize(body);
202        let response = if let Some(response) = self.mocks.get(body) {
203            // Build an http Response
204            let mut http_response_builder = http::Response::builder().status(StatusCode::OK);
205            if let Some(headers) = http_response_builder.headers_mut() {
206                headers.extend(self.headers.iter().map(|(k, v)| (k.clone(), v.clone())));
207            }
208            let http_response = http_response_builder
209                .body(response.clone())
210                .expect("Response is serializable; qed");
211            SubgraphResponse::new_from_response(
212                http_response,
213                req.context,
214                "test".to_string(),
215                req.id,
216            )
217        } else {
218            let error = crate::error::Error::builder()
219                .message(format!(
220                    "couldn't find mock for query {}",
221                    serde_json::to_string(body).unwrap()
222                ))
223                .extension_code("FETCH_ERROR".to_string())
224                .extensions(self.extensions.clone().unwrap_or_default())
225                .build();
226            SubgraphResponse::fake_builder()
227                .error(error)
228                .context(req.context)
229                .subgraph_name(req.subgraph_name.clone())
230                .id(req.id)
231                .build()
232        };
233        future::ok(response)
234    }
235}