rama_http/headers/forwarded/
via.rs1use crate::headers::{self, Header};
2use crate::{HeaderName, HeaderValue};
3use rama_core::error::{ErrorContext, OpaqueError};
4use rama_net::forwarded::{ForwardedElement, ForwardedProtocol, ForwardedVersion, NodeId};
5
6#[derive(Debug, Clone, PartialEq, Eq)]
31pub struct Via(Vec<ViaElement>);
32
33#[derive(Debug, Clone, PartialEq, Eq)]
34struct ViaElement {
35 protocol: Option<ForwardedProtocol>,
36 version: ForwardedVersion,
37 node_id: NodeId,
38}
39
40impl From<ViaElement> for ForwardedElement {
41 fn from(via: ViaElement) -> Self {
42 let mut el = ForwardedElement::forwarded_by(via.node_id);
43 el.set_forwarded_version(via.version);
44 if let Some(protocol) = via.protocol {
45 el.set_forwarded_proto(protocol);
46 }
47 el
48 }
49}
50
51impl Header for Via {
52 fn name() -> &'static HeaderName {
53 &crate::header::VIA
54 }
55
56 fn decode<'i, I: Iterator<Item = &'i HeaderValue>>(
57 values: &mut I,
58 ) -> Result<Self, headers::Error> {
59 crate::headers::util::csv::from_comma_delimited(values).map(Via)
60 }
61
62 fn encode<E: Extend<HeaderValue>>(&self, values: &mut E) {
63 use std::fmt;
64 struct Format<F>(F);
65 impl<F> fmt::Display for Format<F>
66 where
67 F: Fn(&mut fmt::Formatter<'_>) -> fmt::Result,
68 {
69 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
70 (self.0)(f)
71 }
72 }
73 let s = format!(
74 "{}",
75 Format(|f: &mut fmt::Formatter<'_>| {
76 crate::headers::util::csv::fmt_comma_delimited(&mut *f, self.0.iter())
77 })
78 );
79 values.extend(Some(HeaderValue::from_str(&s).unwrap()))
80 }
81}
82
83impl FromIterator<ViaElement> for Via {
84 fn from_iter<T>(iter: T) -> Self
85 where
86 T: IntoIterator<Item = ViaElement>,
87 {
88 Via(iter.into_iter().collect())
89 }
90}
91
92impl super::ForwardHeader for Via {
93 fn try_from_forwarded<'a, I>(input: I) -> Option<Self>
94 where
95 I: IntoIterator<Item = &'a ForwardedElement>,
96 {
97 let vec: Vec<_> = input
98 .into_iter()
99 .filter_map(|el| {
100 let node_id = el.ref_forwarded_by()?.clone();
101 let version = el.ref_forwarded_version()?;
102 let protocol = el.ref_forwarded_proto();
103 Some(ViaElement {
104 protocol,
105 version,
106 node_id,
107 })
108 })
109 .collect();
110 if vec.is_empty() {
111 None
112 } else {
113 Some(Via(vec))
114 }
115 }
116}
117
118impl IntoIterator for Via {
119 type Item = ForwardedElement;
120 type IntoIter = ViaIterator;
121
122 fn into_iter(self) -> Self::IntoIter {
123 ViaIterator(self.0.into_iter())
124 }
125}
126
127#[derive(Debug, Clone)]
128pub struct ViaIterator(std::vec::IntoIter<ViaElement>);
130
131impl Iterator for ViaIterator {
132 type Item = ForwardedElement;
133
134 fn next(&mut self) -> Option<Self::Item> {
135 self.0.next().map(Into::into)
136 }
137}
138
139impl std::str::FromStr for ViaElement {
140 type Err = OpaqueError;
141
142 fn from_str(s: &str) -> Result<Self, Self::Err> {
143 let mut bytes = s.as_bytes();
144
145 bytes = trim_left(bytes);
146
147 let (protocol, version) = match bytes.iter().position(|b| *b == b'/' || *b == b' ') {
148 Some(index) => match bytes[index] {
149 b'/' => {
150 let protocol: ForwardedProtocol = std::str::from_utf8(&bytes[..index])
151 .context("parse via protocol as utf-8")?
152 .try_into()
153 .context("parse via utf-8 protocol as protocol")?;
154 bytes = &bytes[index + 1..];
155 let index = bytes.iter().position(|b| *b == b' ').ok_or_else(|| {
156 OpaqueError::from_display("via str: missing space after protocol separator")
157 })?;
158 let version =
159 ForwardedVersion::try_from(&bytes[..index]).context("parse via version")?;
160 bytes = &bytes[index + 1..];
161 (Some(protocol), version)
162 }
163 b' ' => {
164 let version =
165 ForwardedVersion::try_from(&bytes[..index]).context("parse via version")?;
166 bytes = &bytes[index + 1..];
167 (None, version)
168 }
169 _ => unreachable!(),
170 },
171 None => {
172 return Err(OpaqueError::from_display("via str: missing version"));
173 }
174 };
175
176 bytes = trim_right(trim_left(bytes));
177 let node_id = NodeId::from_bytes_lossy(bytes);
178
179 Ok(Self {
180 protocol,
181 version,
182 node_id,
183 })
184 }
185}
186
187impl std::fmt::Display for ViaElement {
188 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
189 if let Some(ref proto) = self.protocol {
190 write!(f, "{proto}/")?;
191 }
192 write!(f, "{} {}", self.version, self.node_id)
193 }
194}
195
196fn trim_left(b: &[u8]) -> &[u8] {
197 let mut offset = 0;
198 while offset < b.len() && b[offset] == b' ' {
199 offset += 1;
200 }
201 &b[offset..]
202}
203
204fn trim_right(b: &[u8]) -> &[u8] {
205 if b.is_empty() {
206 return b;
207 }
208
209 let mut offset = b.len();
210 while offset > 0 && b[offset - 1] == b' ' {
211 offset -= 1;
212 }
213 &b[..offset]
214}
215
#[cfg(test)]
mod tests {
    use super::*;

    use rama_http_types::HeaderValue;

    /// Node id produced for the CloudFront sample, whose trailing comment is
    /// folded into the node id by the lossy conversion.
    const CLOUDFRONT_NODE: &str = "2e9b3ee4d534903f433e1ed8ea30e57a.cloudfront.net__CloudFront_";

    /// Decodes `raw` as the values of one or more `Via` header lines,
    /// returning `None` when decoding fails.
    fn decode_via(raw: &[&str]) -> Option<Via> {
        let header_values: Vec<HeaderValue> = raw
            .iter()
            .map(|s| HeaderValue::from_bytes(s.as_bytes()).unwrap())
            .collect();
        Via::decode(&mut header_values.iter()).ok()
    }

    /// Builds an expected `ViaElement` from its parts.
    fn via_el(
        protocol: Option<ForwardedProtocol>,
        version: ForwardedVersion,
        node_id: &str,
    ) -> ViaElement {
        ViaElement {
            protocol,
            version,
            node_id: NodeId::try_from_str(node_id).unwrap(),
        }
    }

    #[test]
    fn test1() {
        assert_eq!(
            decode_via(&["1.1 vegur"]),
            Some(Via(vec![via_el(None, ForwardedVersion::HTTP_11, "vegur")])),
        );
    }

    #[test]
    fn test2() {
        // Trailing whitespace after the received-by is ignored.
        assert_eq!(
            decode_via(&["1.1 vegur "]),
            Some(Via(vec![via_el(None, ForwardedVersion::HTTP_11, "vegur")])),
        );
    }

    #[test]
    fn test3() {
        assert_eq!(
            decode_via(&["1.0 fred, 1.1 p.example.net"]),
            Some(Via(vec![
                via_el(None, ForwardedVersion::HTTP_10, "fred"),
                via_el(None, ForwardedVersion::HTTP_11, "p.example.net"),
            ])),
        );
    }

    #[test]
    fn test4() {
        assert_eq!(
            decode_via(&["1.0 fred , 1.1 p.example.net "]),
            Some(Via(vec![
                via_el(None, ForwardedVersion::HTTP_10, "fred"),
                via_el(None, ForwardedVersion::HTTP_11, "p.example.net"),
            ])),
        );
    }

    #[test]
    fn test5() {
        // Entries spread over several header values are decoded in order.
        assert_eq!(
            decode_via(&["1.0 fred", "1.1 p.example.net"]),
            Some(Via(vec![
                via_el(None, ForwardedVersion::HTTP_10, "fred"),
                via_el(None, ForwardedVersion::HTTP_11, "p.example.net"),
            ])),
        );
    }

    #[test]
    fn test6() {
        assert_eq!(
            decode_via(&["HTTP/1.1 proxy.example.re, 1.1 edge_1"]),
            Some(Via(vec![
                via_el(
                    Some(ForwardedProtocol::HTTP),
                    ForwardedVersion::HTTP_11,
                    "proxy.example.re",
                ),
                via_el(None, ForwardedVersion::HTTP_11, "edge_1"),
            ])),
        );
    }

    #[test]
    fn test7() {
        assert_eq!(
            decode_via(&["1.1 2e9b3ee4d534903f433e1ed8ea30e57a.cloudfront.net (CloudFront)"]),
            Some(Via(vec![via_el(
                None,
                ForwardedVersion::HTTP_11,
                CLOUDFRONT_NODE,
            )])),
        );
    }

    #[test]
    fn test_via_symmetric_encoder() {
        let cases = [
            Via(vec![
                via_el(None, ForwardedVersion::HTTP_10, "fred"),
                via_el(None, ForwardedVersion::HTTP_11, "p.example.net"),
            ]),
            Via(vec![
                via_el(
                    Some(ForwardedProtocol::HTTP),
                    ForwardedVersion::HTTP_11,
                    "proxy.example.re",
                ),
                via_el(None, ForwardedVersion::HTTP_11, "edge_1"),
            ]),
            Via(vec![via_el(
                None,
                ForwardedVersion::HTTP_11,
                CLOUDFRONT_NODE,
            )]),
        ];

        for expected in cases {
            let mut encoded = Vec::new();
            expected.encode(&mut encoded);
            let decoded = Via::decode(&mut encoded.iter()).unwrap();
            assert_eq!(expected, decoded);
        }
    }
}
379}