Skip to main content

apollo_federation/connectors/runtime/
http_json_transport.rs

1use std::sync::Arc;
2
3use apollo_compiler::collections::IndexMap;
4use http::HeaderMap;
5use http::HeaderValue;
6use http::header::CONTENT_LENGTH;
7use http::header::CONTENT_TYPE;
8use parking_lot::Mutex;
9use serde_json_bytes::Value;
10use serde_json_bytes::json;
11use thiserror::Error;
12
13use super::form_encoding::encode_json_as_form;
14use crate::connectors::ApplyToError;
15use crate::connectors::HTTPMethod;
16use crate::connectors::Header;
17use crate::connectors::HeaderSource;
18use crate::connectors::HttpJsonTransport;
19use crate::connectors::MakeUriError;
20use crate::connectors::OriginatingDirective;
21use crate::connectors::ProblemLocation;
22use crate::connectors::runtime::debug::ConnectorContext;
23use crate::connectors::runtime::debug::ConnectorDebugHttpRequest;
24use crate::connectors::runtime::debug::DebugRequest;
25use crate::connectors::runtime::debug::SelectionData;
26use crate::connectors::runtime::mapping::Problem;
27use crate::connectors::runtime::mapping::aggregate_apply_to_errors;
28use crate::connectors::runtime::mapping::aggregate_apply_to_errors_with_problem_locations;
29
30/// Request to an HTTP transport
31#[derive(Debug)]
32pub struct HttpRequest {
33    pub inner: http::Request<String>,
34    pub debug: DebugRequest,
35}
36
37/// Response from an HTTP transport
38#[derive(Debug)]
39pub struct HttpResponse {
40    /// The response parts - the body is consumed by applying the JSON mapping
41    pub inner: http::response::Parts,
42}
43
44/// Request to an underlying transport
45#[derive(Debug)]
46pub enum TransportRequest {
47    /// A request to an HTTP transport
48    Http(Box<HttpRequest>),
49    /// A mapping-only request that skips the HTTP transport entirely.
50    /// The selection is applied against an empty object `{}`.
51    MappingOnly,
52}
53
54/// Response from an underlying transport
55#[derive(Debug)]
56pub enum TransportResponse {
57    /// A response from an HTTP transport
58    Http(HttpResponse),
59    /// A mapping-only response (no HTTP transport involved)
60    MappingOnly,
61}
62
63impl From<HttpRequest> for TransportRequest {
64    fn from(value: HttpRequest) -> Self {
65        Self::Http(Box::new(value))
66    }
67}
68
69impl From<HttpResponse> for TransportResponse {
70    fn from(value: HttpResponse) -> Self {
71        Self::Http(value)
72    }
73}
74
75pub fn make_request(
76    transport: &HttpJsonTransport,
77    inputs: IndexMap<String, Value>,
78    client_headers: &HeaderMap<HeaderValue>,
79    debug: &Option<Arc<Mutex<ConnectorContext>>>,
80) -> Result<(TransportRequest, Vec<Problem>), HttpJsonTransportError> {
81    let (uri, uri_apply_to_errors) = transport.make_uri(&inputs)?;
82    let uri_mapping_problems =
83        aggregate_apply_to_errors_with_problem_locations(uri_apply_to_errors);
84
85    let method = transport.method;
86    let request = http::Request::builder()
87        .method(transport.method.as_str())
88        .uri(uri);
89
90    // add the headers and if content-type is specified, we'll check that when constructing the body
91    let (mut request, is_form_urlencoded, header_apply_to_errors) = add_headers(
92        request,
93        client_headers,
94        &transport.headers,
95        &inputs,
96        transport.body.is_some(),
97    );
98    let header_mapping_problems =
99        aggregate_apply_to_errors_with_problem_locations(header_apply_to_errors);
100
101    let (json_body, form_body, body, content_length, body_apply_to_errors) =
102        if let Some(ref selection) = transport.body {
103            let (json_body, apply_to_errors) = selection.apply_with_vars(&json!({}), &inputs);
104            let mut form_body = None;
105            let (body, content_length) = if let Some(json_body) = json_body.as_ref() {
106                if is_form_urlencoded {
107                    let encoded = encode_json_as_form(json_body)
108                        .map_err(HttpJsonTransportError::FormBodySerialization)?;
109                    form_body = Some(encoded.clone());
110                    let len = encoded.len();
111                    (encoded, len)
112                } else {
113                    let bytes = serde_json::to_vec(json_body)?;
114                    let len = bytes.len();
115                    let body_string = serde_json::to_string(json_body)?;
116                    (body_string, len)
117                }
118            } else {
119                ("".into(), 0)
120            };
121            (json_body, form_body, body, content_length, apply_to_errors)
122        } else {
123            (None, None, "".into(), 0, vec![])
124        };
125
126    match method {
127        HTTPMethod::Post | HTTPMethod::Patch | HTTPMethod::Put => {
128            request = request.header(CONTENT_LENGTH, content_length);
129        }
130        _ => {}
131    }
132
133    let request = request
134        .body(body)
135        .map_err(HttpJsonTransportError::InvalidNewRequest)?;
136
137    let body_mapping_problems =
138        aggregate_apply_to_errors(body_apply_to_errors, ProblemLocation::RequestBody);
139
140    let all_problems: Vec<Problem> = uri_mapping_problems
141        .chain(body_mapping_problems)
142        .chain(header_mapping_problems)
143        .collect();
144
145    let debug_request = debug.as_ref().map(|_| {
146        if is_form_urlencoded {
147            Box::new(ConnectorDebugHttpRequest::new(
148                &request,
149                "form-urlencoded".to_string(),
150                form_body.map(|s| Value::String(s.into())).as_ref(),
151                transport.body.as_ref().map(|body| SelectionData {
152                    source: body.to_string(),
153                    transformed: body.to_string(), // no transformation so this is the same
154                    result: json_body,
155                }),
156                transport,
157            ))
158        } else {
159            Box::new(ConnectorDebugHttpRequest::new(
160                &request,
161                "json".to_string(),
162                json_body.as_ref(),
163                transport.body.as_ref().map(|body| SelectionData {
164                    source: body.to_string(),
165                    transformed: body.to_string(), // no transformation so this is the same
166                    result: json_body.clone(),
167                }),
168                transport,
169            ))
170        }
171    });
172
173    Ok((
174        TransportRequest::Http(Box::new(HttpRequest {
175            inner: request,
176            debug: (debug_request, all_problems.clone()),
177        })),
178        all_problems,
179    ))
180}
181
182fn add_headers(
183    mut request: http::request::Builder,
184    incoming_supergraph_headers: &HeaderMap<HeaderValue>,
185    config: &[Header],
186    inputs: &IndexMap<String, Value>,
187    has_body: bool,
188) -> (
189    http::request::Builder,
190    bool,
191    Vec<(ProblemLocation, ApplyToError)>,
192) {
193    let mut content_type = None;
194    let mut warnings = Vec::new();
195
196    for header in config {
197        match &header.source {
198            HeaderSource::From(from) => {
199                let values = incoming_supergraph_headers.get_all(from);
200                let mut propagated = false;
201                for value in values {
202                    request = request.header(header.name.clone(), value.clone());
203                    propagated = true;
204                }
205                if !propagated {
206                    tracing::warn!("Header '{}' not found in incoming request", header.name);
207                }
208            }
209            HeaderSource::Value(value) => match value.interpolate(inputs) {
210                Ok((value, apply_to_errors)) => {
211                    warnings.extend(apply_to_errors.iter().cloned().map(|e| {
212                        (
213                            match header.originating_directive {
214                                OriginatingDirective::Source => ProblemLocation::SourceHeaders,
215                                OriginatingDirective::Connect => ProblemLocation::ConnectHeaders,
216                            },
217                            e,
218                        )
219                    }));
220
221                    if header.name == CONTENT_TYPE {
222                        content_type = Some(value.clone());
223                    }
224
225                    request = request.header(header.name.clone(), value);
226                }
227                Err(err) => {
228                    tracing::error!("Unable to interpolate header value: {:?}", err);
229                }
230            },
231        }
232    }
233
234    let is_form_urlencoded = if let Some(content_type) = content_type {
235        // We don't need to set a content type here because it is set earlier in this function
236        let mine_type = content_type
237            .to_str()
238            .unwrap_or_default()
239            .parse::<mime::Mime>()
240            .ok();
241        mine_type.as_ref() == Some(&mime::APPLICATION_WWW_FORM_URLENCODED)
242    } else {
243        // Only set this content type header as a default if one hasn't been specified. This allows the user to override the value.
244        if has_body {
245            request = request.header(CONTENT_TYPE, mime::APPLICATION_JSON.essence_str());
246        }
247        false
248    };
249
250    (request, is_form_urlencoded, warnings)
251}
252
253#[derive(Error, Debug)]
254pub enum HttpJsonTransportError {
255    #[error("Could not generate HTTP request: {0}")]
256    InvalidNewRequest(#[source] http::Error),
257    #[error("Could not serialize body: {0}")]
258    JsonBodySerialization(#[from] serde_json::Error),
259    #[error("Could not serialize body: {0}")]
260    FormBodySerialization(&'static str),
261    #[error(transparent)]
262    MakeUri(#[from] MakeUriError),
263}
264
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use http::HeaderMap;
    use http::HeaderValue;
    use http::header::CONTENT_ENCODING;
    use insta::assert_debug_snapshot;

    use super::*;
    use crate::connectors::HTTPMethod;
    use crate::connectors::HeaderSource;
    use crate::connectors::JSONSelection;
    use crate::connectors::StringTemplate;

    /// Incoming supergraph headers for the propagation tests: `x-rename` twice, plus an
    /// unrelated `x-ignore` and a `content-encoding`.
    fn sample_incoming_headers() -> HeaderMap<HeaderValue> {
        vec![
            ("x-rename".parse().unwrap(), "renamed".parse().unwrap()),
            ("x-rename".parse().unwrap(), "also-renamed".parse().unwrap()),
            ("x-ignore".parse().unwrap(), "ignored".parse().unwrap()),
            (CONTENT_ENCODING, "gzip".parse().unwrap()),
        ]
        .into_iter()
        .collect()
    }

    /// Inputs with no variables, for header tests that interpolate nothing.
    fn no_inputs() -> IndexMap<String, Value> {
        IndexMap::with_hasher(Default::default())
    }

    /// Inputs exposing `$args = { "a": 42 }` to the body selection.
    fn args_inputs() -> IndexMap<String, Value> {
        let mut inputs = IndexMap::default();
        inputs.insert("$args".to_string(), json!({ "a": 42 }));
        inputs
    }

    /// Extracts the serialized body from an HTTP transport request.
    fn into_http_body(request: TransportRequest) -> String {
        match request {
            TransportRequest::Http(http_request) => {
                let HttpRequest { inner, .. } = *http_request;
                inner.into_body()
            }
            TransportRequest::MappingOnly => panic!("expected Http transport request"),
        }
    }

    #[test]
    fn test_headers_to_add_no_directives() {
        let incoming = sample_incoming_headers();

        let (builder, ..) = add_headers(
            http::Request::builder(),
            &incoming,
            &[],
            &no_inputs(),
            true,
        );
        let request = builder.body("").unwrap();
        let headers = request.headers();

        // Nothing is propagated without configuration; only the default content type remains.
        assert_eq!(headers.len(), 1);
        assert!(headers.get("content-type").is_some());
    }

    #[test]
    fn test_headers_to_add_with_config() {
        let config = [
            Header::from_values(
                "x-new-name".parse().unwrap(),
                HeaderSource::From("x-rename".parse().unwrap()),
                OriginatingDirective::Source,
            ),
            Header::from_values(
                "x-insert".parse().unwrap(),
                HeaderSource::Value("inserted".parse().unwrap()),
                OriginatingDirective::Connect,
            ),
        ];

        let (builder, ..) = add_headers(
            http::Request::builder(),
            &sample_incoming_headers(),
            &config,
            &no_inputs(),
            true,
        );
        let request = builder.body("").unwrap();
        let headers = request.headers();

        // Two propagated `x-new-name` values, one `x-insert`, and the default content type.
        assert_eq!(headers.len(), 4);
        assert_eq!(headers.get("x-new-name"), Some(&"renamed".parse().unwrap()));
        assert_eq!(headers.get("x-insert"), Some(&"inserted".parse().unwrap()));
    }

    #[test]
    fn test_headers_no_content_type_when_no_body() {
        let (builder, ..) = add_headers(
            http::Request::builder(),
            &HeaderMap::new(),
            &[],
            &no_inputs(),
            false,
        );
        let request = builder.body("").unwrap();
        let headers = request.headers();

        assert_eq!(headers.len(), 0);
        assert!(headers.get("content-type").is_none());
    }

    #[test]
    fn test_headers_replace_default_content_type() {
        let mut incoming = HeaderMap::new();
        incoming.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));

        let config = [Header::from_values(
            "content-type".parse().unwrap(),
            HeaderSource::Value("application/vnd.iaas.v1+json".parse().unwrap()),
            OriginatingDirective::Connect,
        )];

        let (builder, ..) = add_headers(
            http::Request::builder(),
            &incoming,
            &config,
            &no_inputs(),
            true,
        );
        let request = builder.body("").unwrap();
        let headers = request.headers();

        // The configured value replaces the JSON default rather than being added to it.
        assert_eq!(headers.len(), 1);
        assert_eq!(
            headers.get("content-type"),
            Some(&"application/vnd.iaas.v1+json".parse().unwrap())
        );
    }

    #[test]
    fn test_headers_multiple_content_type() {
        let mut incoming = HeaderMap::new();
        incoming.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));

        let config = [
            Header::from_values(
                "content-type".parse().unwrap(),
                HeaderSource::Value("application/json".parse().unwrap()),
                OriginatingDirective::Connect,
            ),
            Header::from_values(
                "content-type".parse().unwrap(),
                HeaderSource::Value("application/vnd.iaas.v1+json".parse().unwrap()),
                OriginatingDirective::Connect,
            ),
        ];

        let (builder, ..) = add_headers(
            http::Request::builder(),
            &incoming,
            &config,
            &no_inputs(),
            true,
        );
        let request = builder.body("").unwrap();

        // Repeated configuration appends values in declaration order.
        let content_types: Vec<&HeaderValue> =
            request.headers().get_all("content-type").iter().collect();
        assert_eq!(
            content_types,
            ["application/json", "application/vnd.iaas.v1+json"]
        );
    }

    #[test]
    fn make_request() {
        let transport = HttpJsonTransport {
            source_template: None,
            connect_template: StringTemplate::from_str("http://localhost:8080/").unwrap(),
            method: HTTPMethod::Post,
            body: Some(JSONSelection::parse("$args { a }").unwrap()),
            ..Default::default()
        };

        let req = super::make_request(&transport, args_inputs(), &Default::default(), &None)
            .unwrap();

        assert_debug_snapshot!(req, @r#"
        (
            Http(
                HttpRequest {
                    inner: Request {
                        method: POST,
                        uri: http://localhost:8080/,
                        version: HTTP/1.1,
                        headers: {
                            "content-type": "application/json",
                            "content-length": "8",
                        },
                        body: "{\"a\":42}",
                    },
                    debug: (
                        None,
                        [],
                    ),
                },
            ),
            [],
        )
        "#);

        let body = into_http_body(req.0);
        insta::assert_snapshot!(body, @r#"{"a":42}"#);
    }

    #[test]
    fn make_request_form_encoded() {
        let headers = vec![Header::from_values(
            "content-type".parse().unwrap(),
            HeaderSource::Value("application/x-www-form-urlencoded".parse().unwrap()),
            OriginatingDirective::Connect,
        )];
        let transport = HttpJsonTransport {
            source_template: None,
            connect_template: StringTemplate::from_str("http://localhost:8080/").unwrap(),
            method: HTTPMethod::Post,
            headers,
            body: Some(JSONSelection::parse("$args { a }").unwrap()),
            ..Default::default()
        };

        let req = super::make_request(&transport, args_inputs(), &Default::default(), &None)
            .unwrap();

        assert_debug_snapshot!(req, @r#"
        (
            Http(
                HttpRequest {
                    inner: Request {
                        method: POST,
                        uri: http://localhost:8080/,
                        version: HTTP/1.1,
                        headers: {
                            "content-type": "application/x-www-form-urlencoded",
                            "content-length": "4",
                        },
                        body: "a=42",
                    },
                    debug: (
                        None,
                        [],
                    ),
                },
            ),
            [],
        )
        "#);

        let body = into_http_body(req.0);
        insta::assert_snapshot!(body, @r#"a=42"#);
    }
}