rama_http/headers/forwarded/
via.rs

1use crate::headers::{self, Header};
2use crate::{HeaderName, HeaderValue};
3use rama_core::error::{ErrorContext, OpaqueError};
4use rama_net::forwarded::{ForwardedElement, ForwardedProtocol, ForwardedVersion, NodeId};
5
6/// The Via general header is added by proxies, both forward and reverse.
7///
8/// This header can appear in the request or response headers.
9/// It is used for tracking message forwards, avoiding request loops,
10/// and identifying the protocol capabilities of senders along the request/response chain.
11///
12/// It is recommended to use the [`Forwarded`](super::Forwarded) header instead if you can.
13///
14/// More info can be found at <https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Via>.
15///
16/// # Syntax
17///
18/// ```text
19/// Via: [ <protocol-name> "/" ] <protocol-version> <host> [ ":" <port> ]
20/// Via: [ <protocol-name> "/" ] <protocol-version> <pseudonym>
21/// ```
22///
23/// # Example values
24///
25/// * `1.1 vegur`
26/// * `HTTP/1.1 GWA`
27/// * `1.0 fred, 1.1 p.example.net`
28/// * `HTTP/1.1 proxy.example.re, 1.1 edge_1`
29/// * `1.1 2e9b3ee4d534903f433e1ed8ea30e57a.cloudfront.net (CloudFront)`
30#[derive(Debug, Clone, PartialEq, Eq)]
31pub struct Via(Vec<ViaElement>);
32
33#[derive(Debug, Clone, PartialEq, Eq)]
34struct ViaElement {
35    protocol: Option<ForwardedProtocol>,
36    version: ForwardedVersion,
37    node_id: NodeId,
38}
39
40impl From<ViaElement> for ForwardedElement {
41    fn from(via: ViaElement) -> Self {
42        let mut el = ForwardedElement::forwarded_by(via.node_id);
43        el.set_forwarded_version(via.version);
44        if let Some(protocol) = via.protocol {
45            el.set_forwarded_proto(protocol);
46        }
47        el
48    }
49}
50
51impl Header for Via {
52    fn name() -> &'static HeaderName {
53        &crate::header::VIA
54    }
55
56    fn decode<'i, I: Iterator<Item = &'i HeaderValue>>(
57        values: &mut I,
58    ) -> Result<Self, headers::Error> {
59        crate::headers::util::csv::from_comma_delimited(values).map(Via)
60    }
61
62    fn encode<E: Extend<HeaderValue>>(&self, values: &mut E) {
63        use std::fmt;
64        struct Format<F>(F);
65        impl<F> fmt::Display for Format<F>
66        where
67            F: Fn(&mut fmt::Formatter<'_>) -> fmt::Result,
68        {
69            fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
70                (self.0)(f)
71            }
72        }
73        let s = format!(
74            "{}",
75            Format(|f: &mut fmt::Formatter<'_>| {
76                crate::headers::util::csv::fmt_comma_delimited(&mut *f, self.0.iter())
77            })
78        );
79        values.extend(Some(HeaderValue::from_str(&s).unwrap()))
80    }
81}
82
83impl FromIterator<ViaElement> for Via {
84    fn from_iter<T>(iter: T) -> Self
85    where
86        T: IntoIterator<Item = ViaElement>,
87    {
88        Via(iter.into_iter().collect())
89    }
90}
91
92impl super::ForwardHeader for Via {
93    fn try_from_forwarded<'a, I>(input: I) -> Option<Self>
94    where
95        I: IntoIterator<Item = &'a ForwardedElement>,
96    {
97        let vec: Vec<_> = input
98            .into_iter()
99            .filter_map(|el| {
100                let node_id = el.ref_forwarded_by()?.clone();
101                let version = el.ref_forwarded_version()?;
102                let protocol = el.ref_forwarded_proto();
103                Some(ViaElement {
104                    protocol,
105                    version,
106                    node_id,
107                })
108            })
109            .collect();
110        if vec.is_empty() {
111            None
112        } else {
113            Some(Via(vec))
114        }
115    }
116}
117
118impl IntoIterator for Via {
119    type Item = ForwardedElement;
120    type IntoIter = ViaIterator;
121
122    fn into_iter(self) -> Self::IntoIter {
123        ViaIterator(self.0.into_iter())
124    }
125}
126
127#[derive(Debug, Clone)]
128/// An iterator over the `Via` header's elements.
129pub struct ViaIterator(std::vec::IntoIter<ViaElement>);
130
131impl Iterator for ViaIterator {
132    type Item = ForwardedElement;
133
134    fn next(&mut self) -> Option<Self::Item> {
135        self.0.next().map(Into::into)
136    }
137}
138
139impl std::str::FromStr for ViaElement {
140    type Err = OpaqueError;
141
142    fn from_str(s: &str) -> Result<Self, Self::Err> {
143        let mut bytes = s.as_bytes();
144
145        bytes = trim_left(bytes);
146
147        let (protocol, version) = match bytes.iter().position(|b| *b == b'/' || *b == b' ') {
148            Some(index) => match bytes[index] {
149                b'/' => {
150                    let protocol: ForwardedProtocol = std::str::from_utf8(&bytes[..index])
151                        .context("parse via protocol as utf-8")?
152                        .try_into()
153                        .context("parse via utf-8 protocol as protocol")?;
154                    bytes = &bytes[index + 1..];
155                    let index = bytes.iter().position(|b| *b == b' ').ok_or_else(|| {
156                        OpaqueError::from_display("via str: missing space after protocol separator")
157                    })?;
158                    let version =
159                        ForwardedVersion::try_from(&bytes[..index]).context("parse via version")?;
160                    bytes = &bytes[index + 1..];
161                    (Some(protocol), version)
162                }
163                b' ' => {
164                    let version =
165                        ForwardedVersion::try_from(&bytes[..index]).context("parse via version")?;
166                    bytes = &bytes[index + 1..];
167                    (None, version)
168                }
169                _ => unreachable!(),
170            },
171            None => {
172                return Err(OpaqueError::from_display("via str: missing version"));
173            }
174        };
175
176        bytes = trim_right(trim_left(bytes));
177        let node_id = NodeId::from_bytes_lossy(bytes);
178
179        Ok(Self {
180            protocol,
181            version,
182            node_id,
183        })
184    }
185}
186
187impl std::fmt::Display for ViaElement {
188    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
189        if let Some(ref proto) = self.protocol {
190            write!(f, "{proto}/")?;
191        }
192        write!(f, "{} {}", self.version, self.node_id)
193    }
194}
195
196fn trim_left(b: &[u8]) -> &[u8] {
197    let mut offset = 0;
198    while offset < b.len() && b[offset] == b' ' {
199        offset += 1;
200    }
201    &b[offset..]
202}
203
204fn trim_right(b: &[u8]) -> &[u8] {
205    if b.is_empty() {
206        return b;
207    }
208
209    let mut offset = b.len();
210    while offset > 0 && b[offset - 1] == b' ' {
211        offset -= 1;
212    }
213    &b[..offset]
214}
215
#[cfg(test)]
mod tests {
    use super::*;

    use rama_http_types::HeaderValue;

    /// Builds an expected [`ViaElement`] from its parts.
    fn element(
        protocol: Option<ForwardedProtocol>,
        version: ForwardedVersion,
        node_id: &'static str,
    ) -> ViaElement {
        ViaElement {
            protocol,
            version,
            node_id: NodeId::try_from_str(node_id).unwrap(),
        }
    }

    /// Decodes raw header lines into a [`Via`], returning `None` on failure.
    fn decode_via(raw: &[&str]) -> Option<Via> {
        let values: Vec<HeaderValue> = raw
            .iter()
            .map(|line| HeaderValue::from_bytes(line.as_bytes()).unwrap())
            .collect();
        Via::decode(&mut values.iter()).ok()
    }

    macro_rules! test_header {
        ($name:ident, $input:expr, $expected:expr) => {
            #[test]
            fn $name() {
                assert_eq!(decode_via(&$input), $expected);
            }
        };
    }

    // Tests from the Docs
    test_header!(
        test1,
        ["1.1 vegur"],
        Some(Via(vec![element(None, ForwardedVersion::HTTP_11, "vegur")]))
    );
    test_header!(
        test2,
        ["1.1     vegur    "],
        Some(Via(vec![element(None, ForwardedVersion::HTTP_11, "vegur")]))
    );
    test_header!(
        test3,
        ["1.0 fred, 1.1 p.example.net"],
        Some(Via(vec![
            element(None, ForwardedVersion::HTTP_10, "fred"),
            element(None, ForwardedVersion::HTTP_11, "p.example.net"),
        ]))
    );
    test_header!(
        test4,
        ["1.0 fred    ,    1.1 p.example.net   "],
        Some(Via(vec![
            element(None, ForwardedVersion::HTTP_10, "fred"),
            element(None, ForwardedVersion::HTTP_11, "p.example.net"),
        ]))
    );
    test_header!(
        test5,
        ["1.0 fred", "1.1 p.example.net"],
        Some(Via(vec![
            element(None, ForwardedVersion::HTTP_10, "fred"),
            element(None, ForwardedVersion::HTTP_11, "p.example.net"),
        ]))
    );
    test_header!(
        test6,
        ["HTTP/1.1 proxy.example.re, 1.1 edge_1"],
        Some(Via(vec![
            element(
                Some(ForwardedProtocol::HTTP),
                ForwardedVersion::HTTP_11,
                "proxy.example.re"
            ),
            element(None, ForwardedVersion::HTTP_11, "edge_1"),
        ]))
    );
    test_header!(
        test7,
        ["1.1 2e9b3ee4d534903f433e1ed8ea30e57a.cloudfront.net (CloudFront)"],
        Some(Via(vec![element(
            None,
            ForwardedVersion::HTTP_11,
            "2e9b3ee4d534903f433e1ed8ea30e57a.cloudfront.net__CloudFront_"
        )]))
    );

    #[test]
    fn test_via_symmetric_encoder() {
        let cases = [
            Via(vec![
                element(None, ForwardedVersion::HTTP_10, "fred"),
                element(None, ForwardedVersion::HTTP_11, "p.example.net"),
            ]),
            Via(vec![
                element(
                    Some(ForwardedProtocol::HTTP),
                    ForwardedVersion::HTTP_11,
                    "proxy.example.re",
                ),
                element(None, ForwardedVersion::HTTP_11, "edge_1"),
            ]),
            Via(vec![element(
                None,
                ForwardedVersion::HTTP_11,
                "2e9b3ee4d534903f433e1ed8ea30e57a.cloudfront.net__CloudFront_",
            )]),
        ];

        for expected in cases {
            let mut encoded = Vec::new();
            expected.encode(&mut encoded);
            let decoded = Via::decode(&mut encoded.iter()).unwrap();
            assert_eq!(expected, decoded);
        }
    }
}