rama_http_headers/forwarded/
via.rs

1use crate::{Header, util};
2use rama_core::error::{ErrorContext, OpaqueError};
3use rama_http_types::{HeaderName, HeaderValue, header};
4use rama_net::forwarded::{ForwardedElement, ForwardedProtocol, ForwardedVersion, NodeId};
5
6/// The Via general header is added by proxies, both forward and reverse.
7///
8/// This header can appear in the request or response headers.
9/// It is used for tracking message forwards, avoiding request loops,
10/// and identifying the protocol capabilities of senders along the request/response chain.
11///
12/// It is recommended to use the [`Forwarded`](super::Forwarded) header instead if you can.
13///
14/// More info can be found at <https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Via>.
15///
16/// # Syntax
17///
18/// ```text
19/// Via: [ <protocol-name> "/" ] <protocol-version> <host> [ ":" <port> ]
20/// Via: [ <protocol-name> "/" ] <protocol-version> <pseudonym>
21/// ```
22///
23/// # Example values
24///
25/// * `1.1 vegur`
26/// * `HTTP/1.1 GWA`
27/// * `1.0 fred, 1.1 p.example.net`
28/// * `HTTP/1.1 proxy.example.re, 1.1 edge_1`
29/// * `1.1 2e9b3ee4d534903f433e1ed8ea30e57a.cloudfront.net (CloudFront)`
30#[derive(Debug, Clone, PartialEq, Eq)]
31pub struct Via(Vec<ViaElement>);
32
33#[derive(Debug, Clone, PartialEq, Eq)]
34struct ViaElement {
35    protocol: Option<ForwardedProtocol>,
36    version: ForwardedVersion,
37    node_id: NodeId,
38}
39
40impl From<ViaElement> for ForwardedElement {
41    fn from(via: ViaElement) -> Self {
42        let mut el = ForwardedElement::forwarded_by(via.node_id);
43        el.set_forwarded_version(via.version);
44        if let Some(protocol) = via.protocol {
45            el.set_forwarded_proto(protocol);
46        }
47        el
48    }
49}
50
51impl Header for Via {
52    fn name() -> &'static HeaderName {
53        &header::VIA
54    }
55
56    fn decode<'i, I: Iterator<Item = &'i HeaderValue>>(
57        values: &mut I,
58    ) -> Result<Self, crate::Error> {
59        util::csv::from_comma_delimited(values).map(Via)
60    }
61
62    fn encode<E: Extend<HeaderValue>>(&self, values: &mut E) {
63        use std::fmt;
64        struct Format<F>(F);
65        impl<F> fmt::Display for Format<F>
66        where
67            F: Fn(&mut fmt::Formatter<'_>) -> fmt::Result,
68        {
69            fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
70                (self.0)(f)
71            }
72        }
73        let s = format!(
74            "{}",
75            Format(|f: &mut fmt::Formatter<'_>| {
76                util::csv::fmt_comma_delimited(&mut *f, self.0.iter())
77            })
78        );
79        values.extend(Some(HeaderValue::from_str(&s).unwrap()))
80    }
81}
82
83impl FromIterator<ViaElement> for Via {
84    fn from_iter<T>(iter: T) -> Self
85    where
86        T: IntoIterator<Item = ViaElement>,
87    {
88        Via(iter.into_iter().collect())
89    }
90}
91
92impl super::ForwardHeader for Via {
93    fn try_from_forwarded<'a, I>(input: I) -> Option<Self>
94    where
95        I: IntoIterator<Item = &'a ForwardedElement>,
96    {
97        let vec: Vec<_> = input
98            .into_iter()
99            .filter_map(|el| {
100                let node_id = el.ref_forwarded_by()?.clone();
101                let version = el.ref_forwarded_version()?;
102                let protocol = el.ref_forwarded_proto();
103                Some(ViaElement {
104                    protocol,
105                    version,
106                    node_id,
107                })
108            })
109            .collect();
110        if vec.is_empty() { None } else { Some(Via(vec)) }
111    }
112}
113
114impl IntoIterator for Via {
115    type Item = ForwardedElement;
116    type IntoIter = ViaIterator;
117
118    fn into_iter(self) -> Self::IntoIter {
119        ViaIterator(self.0.into_iter())
120    }
121}
122
123#[derive(Debug, Clone)]
124/// An iterator over the `Via` header's elements.
125pub struct ViaIterator(std::vec::IntoIter<ViaElement>);
126
127impl Iterator for ViaIterator {
128    type Item = ForwardedElement;
129
130    fn next(&mut self) -> Option<Self::Item> {
131        self.0.next().map(Into::into)
132    }
133}
134
135impl std::str::FromStr for ViaElement {
136    type Err = OpaqueError;
137
138    fn from_str(s: &str) -> Result<Self, Self::Err> {
139        let mut bytes = s.as_bytes();
140
141        bytes = trim_left(bytes);
142
143        let (protocol, version) = match bytes.iter().position(|b| *b == b'/' || *b == b' ') {
144            Some(index) => match bytes[index] {
145                b'/' => {
146                    let protocol: ForwardedProtocol = std::str::from_utf8(&bytes[..index])
147                        .context("parse via protocol as utf-8")?
148                        .try_into()
149                        .context("parse via utf-8 protocol as protocol")?;
150                    bytes = &bytes[index + 1..];
151                    let index = bytes.iter().position(|b| *b == b' ').ok_or_else(|| {
152                        OpaqueError::from_display("via str: missing space after protocol separator")
153                    })?;
154                    let version =
155                        ForwardedVersion::try_from(&bytes[..index]).context("parse via version")?;
156                    bytes = &bytes[index + 1..];
157                    (Some(protocol), version)
158                }
159                b' ' => {
160                    let version =
161                        ForwardedVersion::try_from(&bytes[..index]).context("parse via version")?;
162                    bytes = &bytes[index + 1..];
163                    (None, version)
164                }
165                _ => unreachable!(),
166            },
167            None => {
168                return Err(OpaqueError::from_display("via str: missing version"));
169            }
170        };
171
172        bytes = trim_right(trim_left(bytes));
173        let node_id = NodeId::from_bytes_lossy(bytes);
174
175        Ok(Self {
176            protocol,
177            version,
178            node_id,
179        })
180    }
181}
182
183impl std::fmt::Display for ViaElement {
184    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
185        if let Some(ref proto) = self.protocol {
186            write!(f, "{proto}/")?;
187        }
188        write!(f, "{} {}", self.version, self.node_id)
189    }
190}
191
192fn trim_left(b: &[u8]) -> &[u8] {
193    let mut offset = 0;
194    while offset < b.len() && b[offset] == b' ' {
195        offset += 1;
196    }
197    &b[offset..]
198}
199
200fn trim_right(b: &[u8]) -> &[u8] {
201    if b.is_empty() {
202        return b;
203    }
204
205    let mut offset = b.len();
206    while offset > 0 && b[offset - 1] == b' ' {
207        offset -= 1;
208    }
209    &b[..offset]
210}
211
#[cfg(test)]
mod tests {
    use super::*;

    use rama_http_types::HeaderValue;

    /// Decodes the given raw header lines into a [`Via`], returning `None` on failure.
    fn parse_via(raw: &[&str]) -> Option<Via> {
        let values: Vec<HeaderValue> = raw
            .iter()
            .map(|s| HeaderValue::from_bytes(s.as_bytes()).unwrap())
            .collect();
        Via::decode(&mut values.iter()).ok()
    }

    /// Shorthand constructor for an expected [`ViaElement`].
    fn element(
        protocol: Option<ForwardedProtocol>,
        version: ForwardedVersion,
        node: &str,
    ) -> ViaElement {
        ViaElement {
            protocol,
            version,
            node_id: NodeId::try_from_str(node).unwrap(),
        }
    }

    // Tests from the Docs
    #[test]
    fn test1() {
        assert_eq!(
            parse_via(&["1.1 vegur"]),
            Some(Via(vec![element(None, ForwardedVersion::HTTP_11, "vegur")])),
        );
    }

    #[test]
    fn test2() {
        assert_eq!(
            parse_via(&["1.1     vegur    "]),
            Some(Via(vec![element(None, ForwardedVersion::HTTP_11, "vegur")])),
        );
    }

    #[test]
    fn test3() {
        assert_eq!(
            parse_via(&["1.0 fred, 1.1 p.example.net"]),
            Some(Via(vec![
                element(None, ForwardedVersion::HTTP_10, "fred"),
                element(None, ForwardedVersion::HTTP_11, "p.example.net"),
            ])),
        );
    }

    #[test]
    fn test4() {
        assert_eq!(
            parse_via(&["1.0 fred    ,    1.1 p.example.net   "]),
            Some(Via(vec![
                element(None, ForwardedVersion::HTTP_10, "fred"),
                element(None, ForwardedVersion::HTTP_11, "p.example.net"),
            ])),
        );
    }

    #[test]
    fn test5() {
        assert_eq!(
            parse_via(&["1.0 fred", "1.1 p.example.net"]),
            Some(Via(vec![
                element(None, ForwardedVersion::HTTP_10, "fred"),
                element(None, ForwardedVersion::HTTP_11, "p.example.net"),
            ])),
        );
    }

    #[test]
    fn test6() {
        assert_eq!(
            parse_via(&["HTTP/1.1 proxy.example.re, 1.1 edge_1"]),
            Some(Via(vec![
                element(
                    Some(ForwardedProtocol::HTTP),
                    ForwardedVersion::HTTP_11,
                    "proxy.example.re",
                ),
                element(None, ForwardedVersion::HTTP_11, "edge_1"),
            ])),
        );
    }

    #[test]
    fn test7() {
        // The trailing comment is folded lossily into the node id.
        assert_eq!(
            parse_via(&["1.1 2e9b3ee4d534903f433e1ed8ea30e57a.cloudfront.net (CloudFront)"]),
            Some(Via(vec![element(
                None,
                ForwardedVersion::HTTP_11,
                "2e9b3ee4d534903f433e1ed8ea30e57a.cloudfront.net__CloudFront_",
            )])),
        );
    }

    #[test]
    fn test_via_symmetric_encoder() {
        let cases = [
            Via(vec![
                element(None, ForwardedVersion::HTTP_10, "fred"),
                element(None, ForwardedVersion::HTTP_11, "p.example.net"),
            ]),
            Via(vec![
                element(
                    Some(ForwardedProtocol::HTTP),
                    ForwardedVersion::HTTP_11,
                    "proxy.example.re",
                ),
                element(None, ForwardedVersion::HTTP_11, "edge_1"),
            ]),
            Via(vec![element(
                None,
                ForwardedVersion::HTTP_11,
                "2e9b3ee4d534903f433e1ed8ea30e57a.cloudfront.net__CloudFront_",
            )]),
        ];

        for expected in cases {
            // Encoding then decoding must round-trip to an equal header.
            let mut encoded = Vec::new();
            expected.encode(&mut encoded);
            let decoded = Via::decode(&mut encoded.iter()).unwrap();
            assert_eq!(expected, decoded);
        }
    }
}