apollo_router/plugin/test/mock/
subgraph.rs1#![allow(missing_docs)] use std::collections::HashMap;
6use std::sync::Arc;
7use std::task::Poll;
8
9use apollo_compiler::ast::Definition;
10use apollo_compiler::ast::Document;
11use futures::future;
12use http::HeaderMap;
13use http::HeaderName;
14use http::HeaderValue;
15use http::StatusCode;
16use tower::BoxError;
17use tower::Service;
18
19use crate::graphql;
20use crate::graphql::Request;
21use crate::graphql::Response;
22use crate::json_ext::Object;
23use crate::plugins::subscription::notification::Handle;
24use crate::services::SubgraphRequest;
25use crate::services::SubgraphResponse;
26
27type MockResponses = HashMap<Request, Response>;
28
29#[derive(Default, Clone)]
30pub struct MockSubgraph {
31 mocks: Arc<MockResponses>,
33 extensions: Option<Object>,
34 subscription_stream: Option<Handle<String, graphql::Response>>,
35 map_request_fn:
36 Option<Arc<dyn (Fn(SubgraphRequest) -> SubgraphRequest) + Send + Sync + 'static>>,
37 headers: HeaderMap,
38}
39
40impl MockSubgraph {
41 pub fn new(mocks: MockResponses) -> Self {
42 Self {
43 mocks: Arc::new(
44 mocks
45 .into_iter()
46 .map(|(mut req, res)| {
47 normalize(&mut req);
48 (req, res)
49 })
50 .collect(),
51 ),
52 extensions: None,
53 subscription_stream: None,
54 map_request_fn: None,
55 headers: HeaderMap::new(),
56 }
57 }
58
59 pub fn builder() -> MockSubgraphBuilder {
60 MockSubgraphBuilder::default()
61 }
62
63 pub fn with_extensions(mut self, extensions: Object) -> Self {
64 self.extensions = Some(extensions);
65 self
66 }
67
68 pub fn with_subscription_stream(
69 mut self,
70 subscription_stream: Handle<String, graphql::Response>,
71 ) -> Self {
72 self.subscription_stream = Some(subscription_stream);
73 self
74 }
75
76 #[cfg(test)]
77 pub(crate) fn with_map_request<F>(mut self, map_request_fn: F) -> Self
78 where
79 F: (Fn(SubgraphRequest) -> SubgraphRequest) + Send + Sync + 'static,
80 {
81 self.map_request_fn = Some(Arc::new(map_request_fn));
82 self
83 }
84}
85
86#[derive(Default, Clone)]
88pub struct MockSubgraphBuilder {
89 mocks: MockResponses,
90 extensions: Option<Object>,
91 subscription_stream: Option<Handle<String, graphql::Response>>,
92 headers: HeaderMap,
93}
94impl MockSubgraphBuilder {
95 pub fn with_extensions(mut self, extensions: Object) -> Self {
96 self.extensions = Some(extensions);
97 self
98 }
99
100 pub fn with_json(mut self, request: serde_json::Value, response: serde_json::Value) -> Self {
104 let mut request = serde_json::from_value(request).unwrap();
105 normalize(&mut request);
106 self.mocks
107 .insert(request, serde_json::from_value(response).unwrap());
108 self
109 }
110
111 pub fn with_subscription_stream(
112 mut self,
113 subscription_stream: Handle<String, graphql::Response>,
114 ) -> Self {
115 self.subscription_stream = Some(subscription_stream);
116 self
117 }
118
119 pub fn with_header(mut self, name: HeaderName, value: HeaderValue) -> Self {
120 self.headers.insert(name, value);
121 self
122 }
123
124 pub fn build(self) -> MockSubgraph {
125 MockSubgraph {
126 mocks: Arc::new(self.mocks),
127 extensions: self.extensions,
128 subscription_stream: self.subscription_stream,
129 map_request_fn: None,
130 headers: self.headers,
131 }
132 }
133}
134
135fn normalize(request: &mut Request) {
138 if let Some(q) = &request.query {
139 let mut doc = Document::parse(q.clone(), "request").unwrap();
140
141 if let Some(Definition::OperationDefinition(op)) = doc.definitions.first_mut() {
142 let o = op.make_mut();
143 o.name.take();
144 };
145
146 request.query = Some(doc.serialize().no_indent().to_string());
147 request.operation_name = None;
148 }
149}
150
151impl Service<SubgraphRequest> for MockSubgraph {
152 type Response = SubgraphResponse;
153
154 type Error = BoxError;
155
156 type Future = future::Ready<Result<Self::Response, Self::Error>>;
157
158 fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
159 Poll::Ready(Ok(()))
160 }
161
162 fn call(&mut self, mut req: SubgraphRequest) -> Self::Future {
163 if let Some(map_request_fn) = &self.map_request_fn {
164 req = map_request_fn.clone()(req);
165 }
166 let body = req.subgraph_request.body_mut();
167
168 let subscription_stream = self.subscription_stream.clone();
169 if let Some(sub_stream) = &mut req.subscription_stream {
170 sub_stream
171 .try_send(Box::pin(
172 subscription_stream
173 .expect("must have a subscription stream set")
174 .into_stream(),
175 ))
176 .unwrap();
177 }
178
179 if let Some(serde_json_bytes::Value::Object(subscription_ext)) =
181 body.extensions.get_mut("subscription")
182 {
183 if let Some(callback_url) = subscription_ext.get_mut("callbackUrl") {
184 let mut cb_url = url::Url::parse(
185 callback_url
186 .as_str()
187 .expect("callbackUrl extension must be a string"),
188 )
189 .expect("callbackUrl must be a valid URL");
190 cb_url.path_segments_mut().unwrap().pop();
191 cb_url.path_segments_mut().unwrap().push("subscription_id");
192
193 *callback_url = serde_json_bytes::Value::String(cb_url.to_string().into());
194 }
195 if let Some(subscription_id) = subscription_ext.get_mut("subscriptionId") {
196 *subscription_id =
197 serde_json_bytes::Value::String("subscriptionId".to_string().into());
198 }
199 }
200
201 normalize(body);
202 let response = if let Some(response) = self.mocks.get(body) {
203 let mut http_response_builder = http::Response::builder().status(StatusCode::OK);
205 if let Some(headers) = http_response_builder.headers_mut() {
206 headers.extend(self.headers.iter().map(|(k, v)| (k.clone(), v.clone())));
207 }
208 let http_response = http_response_builder
209 .body(response.clone())
210 .expect("Response is serializable; qed");
211 SubgraphResponse::new_from_response(
212 http_response,
213 req.context,
214 "test".to_string(),
215 req.id,
216 )
217 } else {
218 let error = crate::error::Error::builder()
219 .message(format!(
220 "couldn't find mock for query {}",
221 serde_json::to_string(body).unwrap()
222 ))
223 .extension_code("FETCH_ERROR".to_string())
224 .extensions(self.extensions.clone().unwrap_or_default())
225 .build();
226 SubgraphResponse::fake_builder()
227 .error(error)
228 .context(req.context)
229 .subgraph_name(req.subgraph_name.clone())
230 .id(req.id)
231 .build()
232 };
233 future::ok(response)
234 }
235}