rama_http_headers/forwarded/
via.rs1use crate::{Header, util};
2use rama_core::error::{ErrorContext, OpaqueError};
3use rama_http_types::{HeaderName, HeaderValue, header};
4use rama_net::forwarded::{ForwardedElement, ForwardedProtocol, ForwardedVersion, NodeId};
5
6#[derive(Debug, Clone, PartialEq, Eq)]
31pub struct Via(Vec<ViaElement>);
32
33#[derive(Debug, Clone, PartialEq, Eq)]
34struct ViaElement {
35 protocol: Option<ForwardedProtocol>,
36 version: ForwardedVersion,
37 node_id: NodeId,
38}
39
40impl From<ViaElement> for ForwardedElement {
41 fn from(via: ViaElement) -> Self {
42 let mut el = ForwardedElement::forwarded_by(via.node_id);
43 el.set_forwarded_version(via.version);
44 if let Some(protocol) = via.protocol {
45 el.set_forwarded_proto(protocol);
46 }
47 el
48 }
49}
50
51impl Header for Via {
52 fn name() -> &'static HeaderName {
53 &header::VIA
54 }
55
56 fn decode<'i, I: Iterator<Item = &'i HeaderValue>>(
57 values: &mut I,
58 ) -> Result<Self, crate::Error> {
59 util::csv::from_comma_delimited(values).map(Via)
60 }
61
62 fn encode<E: Extend<HeaderValue>>(&self, values: &mut E) {
63 use std::fmt;
64 struct Format<F>(F);
65 impl<F> fmt::Display for Format<F>
66 where
67 F: Fn(&mut fmt::Formatter<'_>) -> fmt::Result,
68 {
69 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
70 (self.0)(f)
71 }
72 }
73 let s = format!(
74 "{}",
75 Format(|f: &mut fmt::Formatter<'_>| {
76 util::csv::fmt_comma_delimited(&mut *f, self.0.iter())
77 })
78 );
79 values.extend(Some(HeaderValue::from_str(&s).unwrap()))
80 }
81}
82
83impl FromIterator<ViaElement> for Via {
84 fn from_iter<T>(iter: T) -> Self
85 where
86 T: IntoIterator<Item = ViaElement>,
87 {
88 Via(iter.into_iter().collect())
89 }
90}
91
92impl super::ForwardHeader for Via {
93 fn try_from_forwarded<'a, I>(input: I) -> Option<Self>
94 where
95 I: IntoIterator<Item = &'a ForwardedElement>,
96 {
97 let vec: Vec<_> = input
98 .into_iter()
99 .filter_map(|el| {
100 let node_id = el.ref_forwarded_by()?.clone();
101 let version = el.ref_forwarded_version()?;
102 let protocol = el.ref_forwarded_proto();
103 Some(ViaElement {
104 protocol,
105 version,
106 node_id,
107 })
108 })
109 .collect();
110 if vec.is_empty() { None } else { Some(Via(vec)) }
111 }
112}
113
114impl IntoIterator for Via {
115 type Item = ForwardedElement;
116 type IntoIter = ViaIterator;
117
118 fn into_iter(self) -> Self::IntoIter {
119 ViaIterator(self.0.into_iter())
120 }
121}
122
123#[derive(Debug, Clone)]
124pub struct ViaIterator(std::vec::IntoIter<ViaElement>);
126
127impl Iterator for ViaIterator {
128 type Item = ForwardedElement;
129
130 fn next(&mut self) -> Option<Self::Item> {
131 self.0.next().map(Into::into)
132 }
133}
134
135impl std::str::FromStr for ViaElement {
136 type Err = OpaqueError;
137
138 fn from_str(s: &str) -> Result<Self, Self::Err> {
139 let mut bytes = s.as_bytes();
140
141 bytes = trim_left(bytes);
142
143 let (protocol, version) = match bytes.iter().position(|b| *b == b'/' || *b == b' ') {
144 Some(index) => match bytes[index] {
145 b'/' => {
146 let protocol: ForwardedProtocol = std::str::from_utf8(&bytes[..index])
147 .context("parse via protocol as utf-8")?
148 .try_into()
149 .context("parse via utf-8 protocol as protocol")?;
150 bytes = &bytes[index + 1..];
151 let index = bytes.iter().position(|b| *b == b' ').ok_or_else(|| {
152 OpaqueError::from_display("via str: missing space after protocol separator")
153 })?;
154 let version =
155 ForwardedVersion::try_from(&bytes[..index]).context("parse via version")?;
156 bytes = &bytes[index + 1..];
157 (Some(protocol), version)
158 }
159 b' ' => {
160 let version =
161 ForwardedVersion::try_from(&bytes[..index]).context("parse via version")?;
162 bytes = &bytes[index + 1..];
163 (None, version)
164 }
165 _ => unreachable!(),
166 },
167 None => {
168 return Err(OpaqueError::from_display("via str: missing version"));
169 }
170 };
171
172 bytes = trim_right(trim_left(bytes));
173 let node_id = NodeId::from_bytes_lossy(bytes);
174
175 Ok(Self {
176 protocol,
177 version,
178 node_id,
179 })
180 }
181}
182
183impl std::fmt::Display for ViaElement {
184 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
185 if let Some(ref proto) = self.protocol {
186 write!(f, "{proto}/")?;
187 }
188 write!(f, "{} {}", self.version, self.node_id)
189 }
190}
191
192fn trim_left(b: &[u8]) -> &[u8] {
193 let mut offset = 0;
194 while offset < b.len() && b[offset] == b' ' {
195 offset += 1;
196 }
197 &b[offset..]
198}
199
200fn trim_right(b: &[u8]) -> &[u8] {
201 if b.is_empty() {
202 return b;
203 }
204
205 let mut offset = b.len();
206 while offset > 0 && b[offset - 1] == b' ' {
207 offset -= 1;
208 }
209 &b[..offset]
210}
211
#[cfg(test)]
mod tests {
    use super::*;

    use rama_http_types::HeaderValue;

    /// Decodes raw header values into a [`Via`], mapping a decode failure to `None`.
    fn decode_via(raw: &[&str]) -> Option<Via> {
        let headers: Vec<HeaderValue> = raw
            .iter()
            .map(|s| HeaderValue::from_bytes(s.as_bytes()).unwrap())
            .collect();
        Via::decode(&mut headers.iter()).ok()
    }

    /// Shorthand for building an expected hop.
    fn element(
        protocol: Option<ForwardedProtocol>,
        version: ForwardedVersion,
        node_id: &str,
    ) -> ViaElement {
        ViaElement {
            protocol,
            version,
            node_id: NodeId::try_from_str(node_id).unwrap(),
        }
    }

    /// Declares a test asserting that decoding `$input` yields `$expected`.
    macro_rules! test_header {
        ($name:ident, $input:expr, $expected:expr) => {
            #[test]
            fn $name() {
                assert_eq!(decode_via(&$input), $expected);
            }
        };
    }

    test_header!(
        test1,
        vec!["1.1 vegur"],
        Some(Via(vec![element(None, ForwardedVersion::HTTP_11, "vegur")]))
    );
    test_header!(
        test2,
        vec!["1.1 vegur "],
        Some(Via(vec![element(None, ForwardedVersion::HTTP_11, "vegur")]))
    );
    test_header!(
        test3,
        vec!["1.0 fred, 1.1 p.example.net"],
        Some(Via(vec![
            element(None, ForwardedVersion::HTTP_10, "fred"),
            element(None, ForwardedVersion::HTTP_11, "p.example.net"),
        ]))
    );
    test_header!(
        test4,
        vec!["1.0 fred , 1.1 p.example.net "],
        Some(Via(vec![
            element(None, ForwardedVersion::HTTP_10, "fred"),
            element(None, ForwardedVersion::HTTP_11, "p.example.net"),
        ]))
    );
    test_header!(
        test5,
        vec!["1.0 fred", "1.1 p.example.net"],
        Some(Via(vec![
            element(None, ForwardedVersion::HTTP_10, "fred"),
            element(None, ForwardedVersion::HTTP_11, "p.example.net"),
        ]))
    );
    test_header!(
        test6,
        vec!["HTTP/1.1 proxy.example.re, 1.1 edge_1"],
        Some(Via(vec![
            element(
                Some(ForwardedProtocol::HTTP),
                ForwardedVersion::HTTP_11,
                "proxy.example.re",
            ),
            element(None, ForwardedVersion::HTTP_11, "edge_1"),
        ]))
    );
    test_header!(
        test7,
        vec!["1.1 2e9b3ee4d534903f433e1ed8ea30e57a.cloudfront.net (CloudFront)"],
        Some(Via(vec![element(
            None,
            ForwardedVersion::HTTP_11,
            "2e9b3ee4d534903f433e1ed8ea30e57a.cloudfront.net__CloudFront_",
        )]))
    );

    #[test]
    fn test_via_symmetric_encoder() {
        let cases = [
            Via(vec![
                element(None, ForwardedVersion::HTTP_10, "fred"),
                element(None, ForwardedVersion::HTTP_11, "p.example.net"),
            ]),
            Via(vec![
                element(
                    Some(ForwardedProtocol::HTTP),
                    ForwardedVersion::HTTP_11,
                    "proxy.example.re",
                ),
                element(None, ForwardedVersion::HTTP_11, "edge_1"),
            ]),
            Via(vec![element(
                None,
                ForwardedVersion::HTTP_11,
                "2e9b3ee4d534903f433e1ed8ea30e57a.cloudfront.net__CloudFront_",
            )]),
        ];

        for expected in cases {
            let mut encoded = Vec::new();
            expected.encode(&mut encoded);
            let decoded = Via::decode(&mut encoded.iter()).unwrap();
            assert_eq!(expected, decoded);
        }
    }
}