Skip to main content

apollo_federation/connectors/models/
headers.rs

1#![deny(clippy::pedantic)]
2
3use std::error::Error;
4use std::fmt::Debug;
5use std::fmt::Display;
6use std::fmt::Formatter;
7
8use apollo_compiler::Name;
9use apollo_compiler::Node;
10use apollo_compiler::ast::Value;
11use apollo_compiler::parser::SourceSpan;
12use either::Either;
13use http::HeaderName;
14use http::header;
15
16use crate::connectors::ConnectSpec;
17use crate::connectors::JSONSelection;
18use crate::connectors::header::HeaderValue;
19use crate::connectors::spec::http::HEADERS_ARGUMENT_NAME;
20use crate::connectors::spec::http::HTTP_HEADER_MAPPING_FROM_ARGUMENT_NAME;
21use crate::connectors::spec::http::HTTP_HEADER_MAPPING_NAME_ARGUMENT_NAME;
22use crate::connectors::spec::http::HTTP_HEADER_MAPPING_VALUE_ARGUMENT_NAME;
23use crate::connectors::string_template;
24
/// Which directive a [`Header`] was declared on: `@source` or `@connect`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OriginatingDirective {
    /// Declared in the `headers` of a `@source` directive.
    Source,
    /// Declared in the `headers` of a `@connect` directive.
    Connect,
}
30
31#[derive(Clone)]
32pub struct Header {
33    pub name: HeaderName,
34    pub(crate) name_node: Option<Node<Value>>,
35    pub source: HeaderSource,
36    pub(crate) source_node: Option<Node<Value>>,
37    pub originating_directive: OriginatingDirective,
38}
39
40impl Header {
41    /// Get a list of headers from the `headers` argument in a `@connect` or `@source` directive.
42    pub(crate) fn from_http_arg(
43        http_arg: &[(Name, Node<Value>)],
44        originating_directive: OriginatingDirective,
45        spec: ConnectSpec,
46    ) -> Vec<Result<Self, HeaderParseError>> {
47        let Some(headers_arg) = http_arg
48            .iter()
49            .find_map(|(key, value)| (*key == HEADERS_ARGUMENT_NAME).then_some(value))
50        else {
51            return Vec::new();
52        };
53        if let Some(values) = headers_arg.as_list() {
54            values
55                .iter()
56                .map(|n| Self::from_single(n, originating_directive, spec))
57                .collect()
58        } else if headers_arg.as_object().is_some() {
59            vec![Self::from_single(headers_arg, originating_directive, spec)]
60        } else {
61            vec![Err(HeaderParseError::Other {
62                message: format!("`{HEADERS_ARGUMENT_NAME}` must be an object or list of objects"),
63                node: headers_arg.clone(),
64            })]
65        }
66    }
67
68    /// Create a single `Header` directly, not from schema. Mostly useful for testing.
69    pub fn from_values(
70        name: HeaderName,
71        source: HeaderSource,
72        originating_directive: OriginatingDirective,
73    ) -> Self {
74        Self {
75            name,
76            name_node: None,
77            source,
78            source_node: None,
79            originating_directive,
80        }
81    }
82
83    /// Build a single [`Self`] from a single entry in the `headers` arg.
84    fn from_single(
85        node: &Node<Value>,
86        originating_directive: OriginatingDirective,
87        spec: ConnectSpec,
88    ) -> Result<Self, HeaderParseError> {
89        let mappings = node.as_object().ok_or_else(|| HeaderParseError::Other {
90            message: "the HTTP header mapping is not an object".to_string(),
91            node: node.clone(),
92        })?;
93        let name_node = mappings
94            .iter()
95            .find_map(|(name, value)| {
96                (*name == HTTP_HEADER_MAPPING_NAME_ARGUMENT_NAME).then_some(value)
97            })
98            .ok_or_else(|| HeaderParseError::Other {
99                message: format!("missing `{HTTP_HEADER_MAPPING_NAME_ARGUMENT_NAME}` field"),
100                node: node.clone(),
101            })?;
102        let name = name_node
103            .as_str()
104            .ok_or_else(|| format!("`{HTTP_HEADER_MAPPING_NAME_ARGUMENT_NAME}` is not a string"))
105            .and_then(|name_str| {
106                HeaderName::try_from(name_str)
107                    .map_err(|_| format!("the value `{name_str}` is an invalid HTTP header name"))
108            })
109            .map_err(|message| HeaderParseError::Other {
110                message,
111                node: name_node.clone(),
112            })?;
113
114        if RESERVED_HEADERS.contains(&name) {
115            return Err(HeaderParseError::Other {
116                message: format!("header '{name}' is reserved and cannot be set by a connector"),
117                node: name_node.clone(),
118            });
119        }
120
121        let from = mappings
122            .iter()
123            .find(|(name, _value)| *name == HTTP_HEADER_MAPPING_FROM_ARGUMENT_NAME);
124        let value = mappings
125            .iter()
126            .find(|(name, _value)| *name == HTTP_HEADER_MAPPING_VALUE_ARGUMENT_NAME);
127
128        match (from, value) {
129            (Some(_), None) if STATIC_HEADERS.contains(&name) => {
130                Err(HeaderParseError::Other{
131                    message: format!(
132                        "header '{name}' can't be set with `{HTTP_HEADER_MAPPING_FROM_ARGUMENT_NAME}`, only with `{HTTP_HEADER_MAPPING_VALUE_ARGUMENT_NAME}`"
133                    ),
134                    node: name_node.clone()
135                })
136            }
137            (Some((_, from_node)), None) => {
138                from_node.as_str()
139                    .ok_or_else(|| format!("`{HTTP_HEADER_MAPPING_FROM_ARGUMENT_NAME}` is not a string"))
140                    .and_then(|from_str| {
141                        HeaderName::try_from(from_str).map_err(|_| {
142                            format!("the value `{from_str}` is an invalid HTTP header name")
143                        })
144                    })
145                    .map(|from| Self {
146                        name,
147                        name_node: Some(name_node.clone()),
148                        source: HeaderSource::From(from),
149                        source_node: Some(from_node.clone()),
150                        originating_directive
151                    })
152                    .map_err(|message| HeaderParseError::Other{ message, node: from_node.clone()})
153            }
154            (None, Some((_, value_node))) => {
155                value_node
156                    .as_str()
157                    .ok_or_else(|| HeaderParseError::Other{
158                        message: format!("`{HTTP_HEADER_MAPPING_VALUE_ARGUMENT_NAME}` field in HTTP header mapping must be a string"),
159                        node: value_node.clone()
160                    })
161                    .and_then(|value_str| {
162                        HeaderValue::parse_with_spec(
163                            value_str,
164                            spec,
165                        )
166                        .map_err(|err| HeaderParseError::ValueError {err, node: value_node.clone()})
167                    })
168                    .map(|value| Self {
169                        name,
170                        name_node: Some(name_node.clone()),
171                        source: HeaderSource::Value(value),
172                        source_node: Some(value_node.clone()),
173                        originating_directive
174                    })
175            }
176            (None, None) => {
177                Err(HeaderParseError::Other {
178                    message: format!("either `{HTTP_HEADER_MAPPING_FROM_ARGUMENT_NAME}` or `{HTTP_HEADER_MAPPING_VALUE_ARGUMENT_NAME}` must be set"),
179                    node: node.clone(),
180                })
181            },
182            (Some((from_name, _)), Some((value_name, _))) => {
183                Err(HeaderParseError::ConflictingArguments {
184                    message: format!("`{HTTP_HEADER_MAPPING_FROM_ARGUMENT_NAME}` and `{HTTP_HEADER_MAPPING_VALUE_ARGUMENT_NAME}` can't be set at the same time"),
185                    from_location: from_name.location(),
186                    value_location: value_name.location(),
187                })
188            }
189        }
190    }
191}
192
193#[allow(clippy::missing_fields_in_debug)]
194impl Debug for Header {
195    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
196        f.debug_struct("Header")
197            .field("name", &self.name)
198            .field("source", &self.source)
199            .finish()
200    }
201}
202
203#[derive(Clone, Debug)]
204pub enum HeaderSource {
205    From(HeaderName),
206    Value(HeaderValue),
207}
208
209impl HeaderSource {
210    pub(crate) fn expressions(&self) -> impl Iterator<Item = &JSONSelection> {
211        match self {
212            HeaderSource::From(_) => Either::Left(std::iter::empty()),
213            HeaderSource::Value(value) => Either::Right(value.expressions().map(|e| &e.expression)),
214        }
215    }
216}
217
218#[derive(Debug)]
219pub(crate) enum HeaderParseError {
220    ValueError {
221        err: string_template::Error,
222        node: Node<Value>,
223    },
224    /// Both `value` and `from` are set
225    ConflictingArguments {
226        message: String,
227        from_location: Option<SourceSpan>,
228        value_location: Option<SourceSpan>,
229    },
230    Other {
231        message: String,
232        node: Node<Value>,
233    },
234}
235
236impl Display for HeaderParseError {
237    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
238        match self {
239            Self::ConflictingArguments { message, .. } | Self::Other { message, .. } => {
240                write!(f, "{message}")
241            }
242            Self::ValueError { err, .. } => write!(f, "{err}"),
243        }
244    }
245}
246
247impl Error for HeaderParseError {}
248
249const RESERVED_HEADERS: [HeaderName; 11] = [
250    header::CONNECTION,
251    header::PROXY_AUTHENTICATE,
252    header::PROXY_AUTHORIZATION,
253    header::TE,
254    header::TRAILER,
255    header::TRANSFER_ENCODING,
256    header::UPGRADE,
257    header::CONTENT_LENGTH,
258    header::CONTENT_ENCODING,
259    header::ACCEPT_ENCODING,
260    HeaderName::from_static("keep-alive"),
261];
262
263const STATIC_HEADERS: [HeaderName; 3] = [header::CONTENT_TYPE, header::ACCEPT, header::HOST];