Skip to main content

apollo_federation/connectors/runtime/
http_json_transport.rs

1use std::sync::Arc;
2
3use apollo_compiler::collections::IndexMap;
4use http::HeaderMap;
5use http::HeaderValue;
6use http::header::CONTENT_LENGTH;
7use http::header::CONTENT_TYPE;
8use parking_lot::Mutex;
9use serde_json_bytes::Value;
10use serde_json_bytes::json;
11use thiserror::Error;
12
13use super::form_encoding::encode_json_as_form;
14use crate::connectors::ApplyToError;
15use crate::connectors::HTTPMethod;
16use crate::connectors::Header;
17use crate::connectors::HeaderSource;
18use crate::connectors::HttpJsonTransport;
19use crate::connectors::MakeUriError;
20use crate::connectors::OriginatingDirective;
21use crate::connectors::ProblemLocation;
22use crate::connectors::runtime::debug::ConnectorContext;
23use crate::connectors::runtime::debug::ConnectorDebugHttpRequest;
24use crate::connectors::runtime::debug::DebugRequest;
25use crate::connectors::runtime::debug::SelectionData;
26use crate::connectors::runtime::mapping::Problem;
27use crate::connectors::runtime::mapping::aggregate_apply_to_errors;
28use crate::connectors::runtime::mapping::aggregate_apply_to_errors_with_problem_locations;
29
/// Request to an HTTP transport
///
/// Pairs the fully-built `http::Request` (with a `String` body) with the
/// debugging information gathered while building it.
#[derive(Debug)]
pub struct HttpRequest {
    /// The request that will be sent; its body is already serialized.
    pub inner: http::Request<String>,
    /// Debug data for this request. `make_request` fills it with the optional
    /// debug record and a copy of the mapping problems it collected.
    pub debug: DebugRequest,
}
36
/// Response from an HTTP transport
///
/// Only the response head (status, headers, version, ...) is kept here.
#[derive(Debug)]
pub struct HttpResponse {
    /// The response parts - the body is consumed by applying the JSON mapping
    pub inner: http::response::Parts,
}
43
/// Request to an underlying transport
///
/// Currently HTTP is the only transport; the enum leaves room for others.
#[derive(Debug)]
pub enum TransportRequest {
    /// A request to an HTTP transport
    Http(HttpRequest),
}
50
/// Response from an underlying transport
///
/// Mirrors [`TransportRequest`]: one variant per supported transport.
#[derive(Debug)]
pub enum TransportResponse {
    /// A response from an HTTP transport
    Http(HttpResponse),
}
57
58impl From<HttpRequest> for TransportRequest {
59    fn from(value: HttpRequest) -> Self {
60        Self::Http(value)
61    }
62}
63
64impl From<HttpResponse> for TransportResponse {
65    fn from(value: HttpResponse) -> Self {
66        Self::Http(value)
67    }
68}
69
70pub fn make_request(
71    transport: &HttpJsonTransport,
72    inputs: IndexMap<String, Value>,
73    client_headers: &HeaderMap<HeaderValue>,
74    debug: &Option<Arc<Mutex<ConnectorContext>>>,
75) -> Result<(TransportRequest, Vec<Problem>), HttpJsonTransportError> {
76    let (uri, uri_apply_to_errors) = transport.make_uri(&inputs)?;
77    let uri_mapping_problems =
78        aggregate_apply_to_errors_with_problem_locations(uri_apply_to_errors);
79
80    let method = transport.method;
81    let request = http::Request::builder()
82        .method(transport.method.as_str())
83        .uri(uri);
84
85    // add the headers and if content-type is specified, we'll check that when constructing the body
86    let (mut request, is_form_urlencoded, header_apply_to_errors) = add_headers(
87        request,
88        client_headers,
89        &transport.headers,
90        &inputs,
91        transport.body.is_some(),
92    );
93    let header_mapping_problems =
94        aggregate_apply_to_errors_with_problem_locations(header_apply_to_errors);
95
96    let (json_body, form_body, body, content_length, body_apply_to_errors) =
97        if let Some(ref selection) = transport.body {
98            let (json_body, apply_to_errors) = selection.apply_with_vars(&json!({}), &inputs);
99            let mut form_body = None;
100            let (body, content_length) = if let Some(json_body) = json_body.as_ref() {
101                if is_form_urlencoded {
102                    let encoded = encode_json_as_form(json_body)
103                        .map_err(HttpJsonTransportError::FormBodySerialization)?;
104                    form_body = Some(encoded.clone());
105                    let len = encoded.len();
106                    (encoded, len)
107                } else {
108                    let bytes = serde_json::to_vec(json_body)?;
109                    let len = bytes.len();
110                    let body_string = serde_json::to_string(json_body)?;
111                    (body_string, len)
112                }
113            } else {
114                ("".into(), 0)
115            };
116            (json_body, form_body, body, content_length, apply_to_errors)
117        } else {
118            (None, None, "".into(), 0, vec![])
119        };
120
121    match method {
122        HTTPMethod::Post | HTTPMethod::Patch | HTTPMethod::Put => {
123            request = request.header(CONTENT_LENGTH, content_length);
124        }
125        _ => {}
126    }
127
128    let request = request
129        .body(body)
130        .map_err(HttpJsonTransportError::InvalidNewRequest)?;
131
132    let body_mapping_problems =
133        aggregate_apply_to_errors(body_apply_to_errors, ProblemLocation::RequestBody);
134
135    let all_problems: Vec<Problem> = uri_mapping_problems
136        .chain(body_mapping_problems)
137        .chain(header_mapping_problems)
138        .collect();
139
140    let debug_request = debug.as_ref().map(|_| {
141        if is_form_urlencoded {
142            Box::new(ConnectorDebugHttpRequest::new(
143                &request,
144                "form-urlencoded".to_string(),
145                form_body.map(|s| Value::String(s.into())).as_ref(),
146                transport.body.as_ref().map(|body| SelectionData {
147                    source: body.to_string(),
148                    transformed: body.to_string(), // no transformation so this is the same
149                    result: json_body,
150                }),
151                transport,
152            ))
153        } else {
154            Box::new(ConnectorDebugHttpRequest::new(
155                &request,
156                "json".to_string(),
157                json_body.as_ref(),
158                transport.body.as_ref().map(|body| SelectionData {
159                    source: body.to_string(),
160                    transformed: body.to_string(), // no transformation so this is the same
161                    result: json_body.clone(),
162                }),
163                transport,
164            ))
165        }
166    });
167
168    Ok((
169        TransportRequest::Http(HttpRequest {
170            inner: request,
171            debug: (debug_request, all_problems.clone()),
172        }),
173        all_problems,
174    ))
175}
176
177fn add_headers(
178    mut request: http::request::Builder,
179    incoming_supergraph_headers: &HeaderMap<HeaderValue>,
180    config: &[Header],
181    inputs: &IndexMap<String, Value>,
182    has_body: bool,
183) -> (
184    http::request::Builder,
185    bool,
186    Vec<(ProblemLocation, ApplyToError)>,
187) {
188    let mut content_type = None;
189    let mut warnings = Vec::new();
190
191    for header in config {
192        match &header.source {
193            HeaderSource::From(from) => {
194                let values = incoming_supergraph_headers.get_all(from);
195                let mut propagated = false;
196                for value in values {
197                    request = request.header(header.name.clone(), value.clone());
198                    propagated = true;
199                }
200                if !propagated {
201                    tracing::warn!("Header '{}' not found in incoming request", header.name);
202                }
203            }
204            HeaderSource::Value(value) => match value.interpolate(inputs) {
205                Ok((value, apply_to_errors)) => {
206                    warnings.extend(apply_to_errors.iter().cloned().map(|e| {
207                        (
208                            match header.originating_directive {
209                                OriginatingDirective::Source => ProblemLocation::SourceHeaders,
210                                OriginatingDirective::Connect => ProblemLocation::ConnectHeaders,
211                            },
212                            e,
213                        )
214                    }));
215
216                    if header.name == CONTENT_TYPE {
217                        content_type = Some(value.clone());
218                    }
219
220                    request = request.header(header.name.clone(), value);
221                }
222                Err(err) => {
223                    tracing::error!("Unable to interpolate header value: {:?}", err);
224                }
225            },
226        }
227    }
228
229    let is_form_urlencoded = if let Some(content_type) = content_type {
230        // We don't need to set a content type here because it is set earlier in this function
231        let mine_type = content_type
232            .to_str()
233            .unwrap_or_default()
234            .parse::<mime::Mime>()
235            .ok();
236        mine_type.as_ref() == Some(&mime::APPLICATION_WWW_FORM_URLENCODED)
237    } else {
238        // Only set this content type header as a default if one hasn't been specified. This allows the user to override the value.
239        if has_body {
240            request = request.header(CONTENT_TYPE, mime::APPLICATION_JSON.essence_str());
241        }
242        false
243    };
244
245    (request, is_form_urlencoded, warnings)
246}
247
/// Errors that prevent an HTTP/JSON connector request from being built.
#[derive(Error, Debug)]
pub enum HttpJsonTransportError {
    /// The `http` request builder rejected the assembled request.
    #[error("Could not generate HTTP request: {0}")]
    InvalidNewRequest(#[source] http::Error),
    /// The body selection result could not be serialized with `serde_json`.
    #[error("Could not serialize body: {0}")]
    JsonBodySerialization(#[from] serde_json::Error),
    /// The body selection result could not be form-urlencoded. Carries the
    /// message returned by `encode_json_as_form`.
    #[error("Could not serialize body: {0}")]
    FormBodySerialization(&'static str),
    /// The request URI could not be built (wraps [`MakeUriError`]).
    #[error(transparent)]
    MakeUri(#[from] MakeUriError),
}
259
// Unit tests for header construction (`add_headers`) and request building
// (`make_request`).
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use http::HeaderMap;
    use http::HeaderValue;
    use http::header::CONTENT_ENCODING;
    use insta::assert_debug_snapshot;

    use super::*;
    use crate::connectors::HTTPMethod;
    use crate::connectors::HeaderSource;
    use crate::connectors::JSONSelection;
    use crate::connectors::StringTemplate;

    // With an empty header config nothing is copied from the incoming request;
    // the only header is the default content-type added because `has_body` is true.
    #[test]
    fn test_headers_to_add_no_directives() {
        let incoming_supergraph_headers: HeaderMap<HeaderValue> = vec![
            ("x-rename".parse().unwrap(), "renamed".parse().unwrap()),
            ("x-rename".parse().unwrap(), "also-renamed".parse().unwrap()),
            ("x-ignore".parse().unwrap(), "ignored".parse().unwrap()),
            (CONTENT_ENCODING, "gzip".parse().unwrap()),
        ]
        .into_iter()
        .collect();

        let request = http::Request::builder();
        let (request, ..) = add_headers(
            request,
            &incoming_supergraph_headers,
            &[],
            &IndexMap::with_hasher(Default::default()),
            true,
        );
        let request = request.body("").unwrap();
        assert_eq!(request.headers().len(), 1);
        assert!(request.headers().get("content-type").is_some());
    }

    // A `From` header copies every incoming value under the new name. With the
    // `Value` header and the default content-type, that makes 4 values in total.
    #[test]
    fn test_headers_to_add_with_config() {
        let incoming_supergraph_headers: HeaderMap<HeaderValue> = vec![
            ("x-rename".parse().unwrap(), "renamed".parse().unwrap()),
            ("x-rename".parse().unwrap(), "also-renamed".parse().unwrap()),
            ("x-ignore".parse().unwrap(), "ignored".parse().unwrap()),
            (CONTENT_ENCODING, "gzip".parse().unwrap()),
        ]
        .into_iter()
        .collect();

        let config = vec![
            Header::from_values(
                "x-new-name".parse().unwrap(),
                HeaderSource::From("x-rename".parse().unwrap()),
                OriginatingDirective::Source,
            ),
            Header::from_values(
                "x-insert".parse().unwrap(),
                HeaderSource::Value("inserted".parse().unwrap()),
                OriginatingDirective::Connect,
            ),
        ];

        let request = http::Request::builder();
        let (request, ..) = add_headers(
            request,
            &incoming_supergraph_headers,
            &config,
            &IndexMap::with_hasher(Default::default()),
            true,
        );
        let request = request.body("").unwrap();
        let result = request.headers();
        assert_eq!(result.len(), 4);
        assert_eq!(result.get("x-new-name"), Some(&"renamed".parse().unwrap()));
        assert_eq!(result.get("x-insert"), Some(&"inserted".parse().unwrap()));
    }

    // Without a body, no default content-type is added.
    #[test]
    fn test_headers_no_content_type_when_no_body() {
        let incoming_supergraph_headers: HeaderMap<HeaderValue> = vec![].into_iter().collect();

        let config = vec![];

        let request = http::Request::builder();
        let (request, ..) = add_headers(
            request,
            &incoming_supergraph_headers,
            &config,
            &IndexMap::with_hasher(Default::default()),
            false,
        );
        let request = request.body("").unwrap();
        let result = request.headers();
        assert_eq!(result.len(), 0);
        assert!(result.get("content-type").is_none());
    }

    // A configured content-type `Value` suppresses the JSON default, and the
    // incoming content-type header is not propagated.
    #[test]
    fn test_headers_replace_default_content_type() {
        let incoming_supergraph_headers: HeaderMap<HeaderValue> = vec![(
            "content-type".parse().unwrap(),
            "application/json".parse().unwrap(),
        )]
        .into_iter()
        .collect();

        let config = vec![Header::from_values(
            "content-type".parse().unwrap(),
            HeaderSource::Value("application/vnd.iaas.v1+json".parse().unwrap()),
            OriginatingDirective::Connect,
        )];

        let request = http::Request::builder();
        let (request, ..) = add_headers(
            request,
            &incoming_supergraph_headers,
            &config,
            &IndexMap::with_hasher(Default::default()),
            true,
        );
        let request = request.body("").unwrap();
        let result = request.headers();
        assert_eq!(result.len(), 1);
        assert_eq!(
            result.get("content-type"),
            Some(&"application/vnd.iaas.v1+json".parse().unwrap())
        );
    }

    // Multiple configured content-type values are all appended, in config order.
    #[test]
    fn test_headers_multiple_content_type() {
        let incoming_supergraph_headers: HeaderMap<HeaderValue> = vec![(
            "content-type".parse().unwrap(),
            "application/json".parse().unwrap(),
        )]
        .into_iter()
        .collect();

        let config = vec![
            Header::from_values(
                "content-type".parse().unwrap(),
                HeaderSource::Value("application/json".parse().unwrap()),
                OriginatingDirective::Connect,
            ),
            Header::from_values(
                "content-type".parse().unwrap(),
                HeaderSource::Value("application/vnd.iaas.v1+json".parse().unwrap()),
                OriginatingDirective::Connect,
            ),
        ];

        let request = http::Request::builder();
        let (request, ..) = add_headers(
            request,
            &incoming_supergraph_headers,
            &config,
            &IndexMap::with_hasher(Default::default()),
            true,
        );
        let request = request.body("").unwrap();
        let result = request.headers();

        let content_type_values: Vec<&HeaderValue> =
            result.get_all("content-type").iter().collect();
        assert_eq!(content_type_values.len(), 2);
        assert_eq!(content_type_values[0], "application/json");
        assert_eq!(content_type_values[1], "application/vnd.iaas.v1+json");
    }

    // POST with a JSON body selection: JSON body, default content-type and a
    // matching content-length.
    #[test]
    fn make_request() {
        let mut vars = IndexMap::default();
        vars.insert("$args".to_string(), json!({ "a": 42 }));

        let req = super::make_request(
            &HttpJsonTransport {
                source_template: None,
                connect_template: StringTemplate::from_str("http://localhost:8080/").unwrap(),
                method: HTTPMethod::Post,
                body: Some(JSONSelection::parse("$args { a }").unwrap()),
                ..Default::default()
            },
            vars,
            &Default::default(),
            &None,
        )
        .unwrap();

        assert_debug_snapshot!(req, @r#"
        (
            Http(
                HttpRequest {
                    inner: Request {
                        method: POST,
                        uri: http://localhost:8080/,
                        version: HTTP/1.1,
                        headers: {
                            "content-type": "application/json",
                            "content-length": "8",
                        },
                        body: "{\"a\":42}",
                    },
                    debug: (
                        None,
                        [],
                    ),
                },
            ),
            [],
        )
        "#);

        let TransportRequest::Http(HttpRequest { inner: req, .. }) = req.0;
        let body = req.into_body();
        insta::assert_snapshot!(body, @r#"{"a":42}"#);
    }

    // POST with an explicit form-urlencoded content-type: the body selection is
    // form-encoded instead of JSON.
    #[test]
    fn make_request_form_encoded() {
        let mut vars = IndexMap::default();
        vars.insert("$args".to_string(), json!({ "a": 42 }));
        let headers = vec![Header::from_values(
            "content-type".parse().unwrap(),
            HeaderSource::Value("application/x-www-form-urlencoded".parse().unwrap()),
            OriginatingDirective::Connect,
        )];

        let req = super::make_request(
            &HttpJsonTransport {
                source_template: None,
                connect_template: StringTemplate::from_str("http://localhost:8080/").unwrap(),
                method: HTTPMethod::Post,
                headers,
                body: Some(JSONSelection::parse("$args { a }").unwrap()),
                ..Default::default()
            },
            vars,
            &Default::default(),
            &None,
        )
        .unwrap();

        assert_debug_snapshot!(req, @r#"
        (
            Http(
                HttpRequest {
                    inner: Request {
                        method: POST,
                        uri: http://localhost:8080/,
                        version: HTTP/1.1,
                        headers: {
                            "content-type": "application/x-www-form-urlencoded",
                            "content-length": "4",
                        },
                        body: "a=42",
                    },
                    debug: (
                        None,
                        [],
                    ),
                },
            ),
            [],
        )
        "#);

        let TransportRequest::Http(HttpRequest { inner: req, .. }) = req.0;
        let body = req.into_body();
        insta::assert_snapshot!(body, @r#"a=42"#);
    }
}