rama_net/
proto.rs

1use smol_str::SmolStr;
2use std::cmp::min;
3use std::str::FromStr;
4
5use rama_core::error::{ErrorContext, OpaqueError};
6use rama_utils::macros::str::eq_ignore_ascii_case;
7
8#[cfg(feature = "http")]
9use rama_http_types::{Method, Scheme};
10
11#[cfg(feature = "http")]
12use tracing::{trace, warn};
13
14#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
15/// Web protocols that are relevant to Rama.
16///
17/// Please [file an issue or open a PR][repo] if you need support for more protocols.
18/// When doing so please provide sufficient motivation and ensure
19/// it has no unintended consequences.
20///
21/// [repo]: https://github.com/plabayo/rama
22pub struct Protocol(ProtocolKind);
23
24impl Protocol {
25    #[cfg(feature = "http")]
26    pub fn maybe_from_uri_scheme_str_and_method(
27        s: Option<&Scheme>,
28        method: Option<&Method>,
29    ) -> Option<Self> {
30        s.map(|s| {
31            trace!("detected protocol from scheme");
32            let protocol: Protocol = s.into();
33            if method == Some(&Method::CONNECT) {
34                match protocol {
35                    Protocol::HTTP => {
36                        trace!("CONNECT request: upgrade HTTP => HTTPS");
37                        Protocol::HTTPS
38                    }
39                    Protocol::HTTPS => Protocol::HTTPS,
40                    Protocol::WS => {
41                        trace!("CONNECT request: upgrade WS => WSS");
42                        Protocol::WSS
43                    }
44                    Protocol::WSS => Protocol::WSS,
45                    other => {
46                        warn!(protocol = %other, "CONNECT request: unexpected protocol");
47                        other
48                    }
49                }
50            } else {
51                protocol
52            }
53        })
54    }
55}
56
57#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
58enum ProtocolKind {
59    /// The `http` protocol.
60    Http,
61    /// The `https` protocol.
62    Https,
63    /// The `ws` protocol.
64    ///
65    /// (Websocket over HTTP)
66    /// <https://datatracker.ietf.org/doc/html/rfc6455>
67    Ws,
68    /// The `wss` protocol.
69    ///
70    /// (Websocket over HTTPS)
71    /// <https://datatracker.ietf.org/doc/html/rfc6455>
72    Wss,
73    /// The `socks5` protocol.
74    ///
75    /// <https://datatracker.ietf.org/doc/html/rfc1928>
76    Socks5,
77    /// The `socks5h` protocol.
78    ///
79    /// Not official, but rather a convention that was introduced in version 4 of socks,
80    /// by curl and documented at <https://curl.se/libcurl/c/CURLOPT_PROXY.html>.
81    ///
82    /// The difference with [`Self::Socks5`] is that the proxy resolves the URL hostname.
83    Socks5h,
84    /// Custom protocol.
85    Custom(SmolStr),
86}
87
// Canonical (lowercase) scheme names of the built-in protocols.

/// Scheme of [`Protocol::HTTP`].
const SCHEME_HTTP: &str = "http";
/// Scheme of [`Protocol::HTTPS`].
const SCHEME_HTTPS: &str = "https";
/// Scheme of [`Protocol::WS`].
const SCHEME_WS: &str = "ws";
/// Scheme of [`Protocol::WSS`].
const SCHEME_WSS: &str = "wss";
/// Scheme of [`Protocol::SOCKS5`].
const SCHEME_SOCKS5: &str = "socks5";
/// Scheme of [`Protocol::SOCKS5H`].
const SCHEME_SOCKS5H: &str = "socks5h";
94
95impl Protocol {
96    /// `HTTP` protocol.
97    pub const HTTP: Self = Protocol(ProtocolKind::Http);
98
99    /// `HTTPS` protocol.
100    pub const HTTPS: Self = Protocol(ProtocolKind::Https);
101
102    /// `WS` protocol.
103    pub const WS: Self = Protocol(ProtocolKind::Ws);
104
105    /// `WSS` protocol.
106    pub const WSS: Self = Protocol(ProtocolKind::Wss);
107
108    /// `SOCKS5` protocol.
109    pub const SOCKS5: Self = Protocol(ProtocolKind::Socks5);
110
111    /// `SOCKS5H` protocol.
112    pub const SOCKS5H: Self = Protocol(ProtocolKind::Socks5h);
113
114    /// Creates a Protocol from a str a compile time.
115    ///
116    /// This function requires the static string to be a valid protocol.
117    ///
118    /// It is intended to be used to facilitate the compile-time creation of
119    /// custom Protocols, as known protocols are easier created by using the desired
120    /// variant directly.
121    ///
122    /// # Panics
123    ///
124    /// This function panics at **compile time** when the static string is not a valid protocol.
125    pub const fn from_static(s: &'static str) -> Self {
126        // NOTE: once unwrapping is possible in const we can piggy back on
127        // `try_to_convert_str_to_non_custom_protocol`
128
129        Protocol(if eq_ignore_ascii_case!(s, SCHEME_HTTPS) {
130            ProtocolKind::Https
131        } else if s.is_empty() || eq_ignore_ascii_case!(s, SCHEME_HTTP) {
132            ProtocolKind::Http
133        } else if eq_ignore_ascii_case!(s, SCHEME_SOCKS5) {
134            ProtocolKind::Socks5
135        } else if eq_ignore_ascii_case!(s, SCHEME_SOCKS5H) {
136            ProtocolKind::Socks5h
137        } else if eq_ignore_ascii_case!(s, SCHEME_WS) {
138            ProtocolKind::Ws
139        } else if eq_ignore_ascii_case!(s, SCHEME_WSS) {
140            ProtocolKind::Wss
141        } else if validate_scheme_str(s) {
142            ProtocolKind::Custom(SmolStr::new_static(s))
143        } else {
144            panic!("invalid static protocol str");
145        })
146    }
147
148    /// Returns `true` if this protocol is http(s).
149    pub fn is_http(&self) -> bool {
150        match &self.0 {
151            ProtocolKind::Http | ProtocolKind::Https => true,
152            ProtocolKind::Ws
153            | ProtocolKind::Wss
154            | ProtocolKind::Socks5
155            | ProtocolKind::Socks5h
156            | ProtocolKind::Custom(_) => false,
157        }
158    }
159
160    /// Returns `true` if this protocol is ws(s).
161    pub fn is_ws(&self) -> bool {
162        match &self.0 {
163            ProtocolKind::Ws | ProtocolKind::Wss => true,
164            ProtocolKind::Http
165            | ProtocolKind::Https
166            | ProtocolKind::Socks5
167            | ProtocolKind::Socks5h
168            | ProtocolKind::Custom(_) => false,
169        }
170    }
171
172    /// Returns `true` if this protocol is socks5.
173    pub fn is_socks5(&self) -> bool {
174        match &self.0 {
175            ProtocolKind::Socks5 => true,
176            ProtocolKind::Http
177            | ProtocolKind::Https
178            | ProtocolKind::Ws
179            | ProtocolKind::Wss
180            | ProtocolKind::Socks5h
181            | ProtocolKind::Custom(_) => false,
182        }
183    }
184
185    /// Returns `true` if this protocol is socks5h).
186    pub fn is_socks5h(&self) -> bool {
187        match &self.0 {
188            ProtocolKind::Socks5h => true,
189            ProtocolKind::Socks5
190            | ProtocolKind::Http
191            | ProtocolKind::Https
192            | ProtocolKind::Ws
193            | ProtocolKind::Wss
194            | ProtocolKind::Custom(_) => false,
195        }
196    }
197
198    /// Returns `true` if this protocol is "secure" by itself.
199    pub fn is_secure(&self) -> bool {
200        match &self.0 {
201            ProtocolKind::Https | ProtocolKind::Wss => true,
202            ProtocolKind::Ws
203            | ProtocolKind::Http
204            | ProtocolKind::Socks5
205            | ProtocolKind::Socks5h
206            | ProtocolKind::Custom(_) => false,
207        }
208    }
209
210    /// Returns the default port for this [`Protocol`]
211    pub fn default_port(&self) -> Option<u16> {
212        match &self.0 {
213            ProtocolKind::Https | ProtocolKind::Wss => Some(443),
214            ProtocolKind::Http | ProtocolKind::Ws => Some(80),
215            ProtocolKind::Socks5 | ProtocolKind::Socks5h => Some(1080),
216            ProtocolKind::Custom(_) => None,
217        }
218    }
219
220    /// Returns the [`Protocol`] as a string.
221    pub fn as_str(&self) -> &str {
222        match &self.0 {
223            ProtocolKind::Http => "http",
224            ProtocolKind::Https => "https",
225            ProtocolKind::Ws => "ws",
226            ProtocolKind::Wss => "wss",
227            ProtocolKind::Socks5 => "socks5",
228            ProtocolKind::Socks5h => "socks5h",
229            ProtocolKind::Custom(s) => s.as_ref(),
230        }
231    }
232}
233
// Error returned by the `TryFrom`/`FromStr` conversions into `Protocol`
// when the input is not a valid scheme string.
rama_utils::macros::error::static_str_error! {
    #[doc = "invalid protocol string"]
    pub struct InvalidProtocolStr;
}
238
239fn try_to_convert_str_to_non_custom_protocol(
240    s: &str,
241) -> Result<Option<Protocol>, InvalidProtocolStr> {
242    Ok(Some(Protocol(if eq_ignore_ascii_case!(s, SCHEME_HTTPS) {
243        ProtocolKind::Https
244    } else if s.is_empty() || eq_ignore_ascii_case!(s, SCHEME_HTTP) {
245        ProtocolKind::Http
246    } else if eq_ignore_ascii_case!(s, SCHEME_SOCKS5) {
247        ProtocolKind::Socks5
248    } else if eq_ignore_ascii_case!(s, SCHEME_SOCKS5H) {
249        ProtocolKind::Socks5h
250    } else if eq_ignore_ascii_case!(s, SCHEME_WS) {
251        ProtocolKind::Ws
252    } else if eq_ignore_ascii_case!(s, SCHEME_WSS) {
253        ProtocolKind::Wss
254    } else if validate_scheme_str(s) {
255        return Ok(None);
256    } else {
257        return Err(InvalidProtocolStr);
258    })))
259}
260
261impl TryFrom<&str> for Protocol {
262    type Error = InvalidProtocolStr;
263
264    fn try_from(s: &str) -> Result<Self, Self::Error> {
265        Ok(try_to_convert_str_to_non_custom_protocol(s)?
266            .unwrap_or_else(|| Protocol(ProtocolKind::Custom(SmolStr::new_inline(s)))))
267    }
268}
269
270impl TryFrom<String> for Protocol {
271    type Error = InvalidProtocolStr;
272
273    fn try_from(s: String) -> Result<Self, Self::Error> {
274        Ok(try_to_convert_str_to_non_custom_protocol(&s)?
275            .unwrap_or(Protocol(ProtocolKind::Custom(SmolStr::new(s)))))
276    }
277}
278
279impl TryFrom<&String> for Protocol {
280    type Error = InvalidProtocolStr;
281
282    fn try_from(s: &String) -> Result<Self, Self::Error> {
283        Ok(try_to_convert_str_to_non_custom_protocol(s)?
284            .unwrap_or_else(|| Protocol(ProtocolKind::Custom(SmolStr::new(s)))))
285    }
286}
287
288impl FromStr for Protocol {
289    type Err = InvalidProtocolStr;
290
291    fn from_str(s: &str) -> Result<Self, Self::Err> {
292        s.try_into()
293    }
294}
295
#[cfg(feature = "http")]
impl From<Scheme> for Protocol {
    /// Converts an owned [`Scheme`] by delegating to the by-reference conversion.
    #[inline]
    fn from(s: Scheme) -> Self {
        Self::from(&s)
    }
}
305
#[cfg(feature = "http")]
impl From<&Scheme> for Protocol {
    /// Converts a [`Scheme`] into a [`Protocol`].
    ///
    /// Panics if the scheme's string is rejected by the `TryFrom<&str>`
    /// conversion. This is expected never to happen, because the scheme was
    /// already validated when it was parsed.
    fn from(s: &Scheme) -> Self {
        let scheme_str = s.as_str();
        Protocol::try_from(scheme_str).expect("http crate Scheme is pre-validated by promise")
    }
}
314
315impl PartialEq<str> for Protocol {
316    fn eq(&self, other: &str) -> bool {
317        match &self.0 {
318            ProtocolKind::Https => other.eq_ignore_ascii_case(SCHEME_HTTPS),
319            ProtocolKind::Http => other.eq_ignore_ascii_case(SCHEME_HTTP) || other.is_empty(),
320            ProtocolKind::Socks5 => other.eq_ignore_ascii_case(SCHEME_SOCKS5),
321            ProtocolKind::Socks5h => other.eq_ignore_ascii_case(SCHEME_SOCKS5H),
322            ProtocolKind::Ws => other.eq_ignore_ascii_case("ws"),
323            ProtocolKind::Wss => other.eq_ignore_ascii_case("wss"),
324            ProtocolKind::Custom(s) => other.eq_ignore_ascii_case(s),
325        }
326    }
327}
328
329impl PartialEq<String> for Protocol {
330    fn eq(&self, other: &String) -> bool {
331        self == other.as_str()
332    }
333}
334
335impl PartialEq<&str> for Protocol {
336    fn eq(&self, other: &&str) -> bool {
337        self == *other
338    }
339}
340
341impl PartialEq<Protocol> for str {
342    fn eq(&self, other: &Protocol) -> bool {
343        other == self
344    }
345}
346
347impl PartialEq<Protocol> for String {
348    fn eq(&self, other: &Protocol) -> bool {
349        other == self.as_str()
350    }
351}
352
353impl PartialEq<Protocol> for &str {
354    fn eq(&self, other: &Protocol) -> bool {
355        other == *self
356    }
357}
358
359impl std::fmt::Display for Protocol {
360    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
361        self.as_str().fmt(f)
362    }
363}
364
365pub(crate) fn try_to_extract_protocol_from_uri_scheme(
366    s: &[u8],
367) -> Result<(Option<Protocol>, usize), OpaqueError> {
368    if s.is_empty() {
369        return Err(OpaqueError::from_display("empty uri contains no scheme"));
370    }
371
372    for i in 0..min(s.len(), 512) {
373        let b = s[i];
374
375        if b == b':' {
376            // Not enough data remaining
377            if s.len() < i + 3 {
378                break;
379            }
380
381            // Not a scheme
382            if &s[i + 1..i + 3] != b"//" {
383                break;
384            }
385
386            let str =
387                std::str::from_utf8(&s[..i]).context("interpret scheme bytes as utf-8 str")?;
388            let protocol = str
389                .try_into()
390                .context("parse scheme utf-8 str as protocol")?;
391            return Ok((Some(protocol), i + 3));
392        }
393    }
394
395    Ok((None, 0))
396}
397
/// Returns `true` if `s` is an acceptable scheme string.
///
/// See [`validate_scheme_slice`] for the exact rules.
#[inline]
const fn validate_scheme_str(s: &str) -> bool {
    let bytes = s.as_bytes();
    validate_scheme_slice(bytes)
}

/// Returns `true` if `s` is non-empty, at most [`MAX_SCHEME_LEN`] bytes long,
/// and made up solely of bytes marked valid in [`SCHEME_CHARS`]
/// (ASCII letters, digits, `+`, `-` and `.`).
///
/// NOTE(review): unlike the RFC 3986 grammar, the first byte is not required
/// to be an ASCII letter. Confirm that callers relying on pre-validated
/// schemes (e.g. `From<&Scheme>`) would not start panicking before tightening this.
const fn validate_scheme_slice(s: &[u8]) -> bool {
    if s.is_empty() || s.len() > MAX_SCHEME_LEN {
        return false;
    }

    // Iterators are not usable in a `const fn`, hence the manual loop.
    let mut remaining = s.len();
    while remaining > 0 {
        remaining -= 1;
        if SCHEME_CHARS[s[remaining] as usize] == 0 {
            return false;
        }
    }
    true
}

/// Maximum accepted length of a scheme, in bytes.
///
/// Requiring the scheme to not be too long enables further optimizations later.
const MAX_SCHEME_LEN: usize = 64;

/// Lookup table of the bytes allowed in the scheme part of a URI:
/// `scheme = ALPHA *( ALPHA / DIGIT / "+" / "-" / "." )`.
///
/// The entry for a valid byte is that byte itself (e.g. index 43 holds `b'+'`),
/// and 0 for every invalid byte. All entries above 127 are invalid, so every
/// valid entry is a single-byte UTF-8 code point. A slice made up only of
/// valid entries is therefore valid UTF-8.
const SCHEME_CHARS: [u8; 256] = {
    let mut table = [0u8; 256];
    let mut idx = 0;
    while idx < table.len() {
        // `idx < 256`, so this cast never truncates.
        let byte = idx as u8;
        if byte.is_ascii_alphanumeric() || byte == b'+' || byte == b'-' || byte == b'.' {
            table[idx] = byte;
        }
        idx += 1;
    }
    table
};
460
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_from_str() {
        assert_eq!("http".parse(), Ok(Protocol::HTTP));
        assert_eq!("".parse(), Ok(Protocol::HTTP));
        assert_eq!("https".parse(), Ok(Protocol::HTTPS));
        assert_eq!("ws".parse(), Ok(Protocol::WS));
        assert_eq!("wss".parse(), Ok(Protocol::WSS));
        assert_eq!("socks5".parse(), Ok(Protocol::SOCKS5));
        assert_eq!("socks5h".parse(), Ok(Protocol::SOCKS5H));
        assert_eq!("custom".parse(), Ok(Protocol::from_static("custom")));
    }

    #[test]
    fn test_from_str_known_protocols_ignore_ascii_case() {
        assert_eq!("HTTP".parse(), Ok(Protocol::HTTP));
        assert_eq!("HttpS".parse(), Ok(Protocol::HTTPS));
        assert_eq!("WSS".parse(), Ok(Protocol::WSS));
        assert_eq!("Socks5H".parse(), Ok(Protocol::SOCKS5H));
    }

    #[test]
    fn test_from_str_invalid() {
        let too_long = "a".repeat(MAX_SCHEME_LEN + 1);
        for s in ["a b", "http:", "h/tp", "h\u{e9}llo", too_long.as_str()] {
            assert!(s.parse::<Protocol>().is_err(), "case: {s}");
        }
    }

    #[test]
    fn test_eq_str_ignores_ascii_case() {
        assert_eq!(Protocol::HTTP, "HTTP");
        assert_eq!(Protocol::HTTP, "");
        assert_eq!(Protocol::from_static("custom"), "CUSTOM");
        assert_ne!(Protocol::HTTPS, "http");
    }

    #[test]
    fn test_display_honours_formatter_flags() {
        assert_eq!(Protocol::HTTPS.to_string(), "https");
        assert_eq!(format!("{:>6}", Protocol::HTTP), "  http");
    }

    #[test]
    fn test_default_port() {
        assert_eq!(Protocol::HTTP.default_port(), Some(80));
        assert_eq!(Protocol::WS.default_port(), Some(80));
        assert_eq!(Protocol::HTTPS.default_port(), Some(443));
        assert_eq!(Protocol::WSS.default_port(), Some(443));
        assert_eq!(Protocol::SOCKS5.default_port(), Some(1080));
        assert_eq!(Protocol::SOCKS5H.default_port(), Some(1080));
        assert_eq!(Protocol::from_static("custom").default_port(), None);
    }

    #[cfg(feature = "http")]
    #[test]
    fn test_from_http_scheme() {
        for s in [
            "http", "https", "ws", "wss", "socks5", "socks5h", "", "custom",
        ]
        .iter()
        {
            let uri =
                rama_http_types::Uri::from_str(format!("{}://example.com", s).as_str()).unwrap();
            assert_eq!(Protocol::from(uri.scheme().unwrap()), *s);
        }
    }

    #[test]
    fn test_scheme_is_secure() {
        assert!(!Protocol::HTTP.is_secure());
        assert!(Protocol::HTTPS.is_secure());
        assert!(!Protocol::SOCKS5.is_secure());
        assert!(!Protocol::SOCKS5H.is_secure());
        assert!(!Protocol::WS.is_secure());
        assert!(Protocol::WSS.is_secure());
        assert!(!Protocol::from_static("custom").is_secure());
    }

    #[test]
    fn test_try_to_extract_protocol_from_uri_scheme() {
        for (s, expected) in [
            ("", None),
            ("http://example.com", Some((Some(Protocol::HTTP), 7))),
            ("https://example.com", Some((Some(Protocol::HTTPS), 8))),
            ("ws://example.com", Some((Some(Protocol::WS), 5))),
            ("wss://example.com", Some((Some(Protocol::WSS), 6))),
            ("socks5://example.com", Some((Some(Protocol::SOCKS5), 9))),
            ("socks5h://example.com", Some((Some(Protocol::SOCKS5H), 10))),
            (
                "custom://example.com",
                Some((Some(Protocol::from_static("custom")), 9)),
            ),
            (" http://example.com", None),
            ("example.com", Some((None, 0))),
            ("127.0.0.1", Some((None, 0))),
            ("127.0.0.1:8080", Some((None, 0))),
            (
                "longlonglongwaytoolongforsomethingusefulorvaliddontyouthinkmydearreader://example.com",
                None,
            ),
        ] {
            let result = try_to_extract_protocol_from_uri_scheme(s.as_bytes());
            match expected {
                Some(t) => match result {
                    Err(err) => panic!("unexpected err: {err} (case: {s})"),
                    Ok(p) => assert_eq!(t, p, "case: {}", s),
                },
                None => assert!(result.is_err(), "case: {}, result: {:?}", s, result),
            }
        }
    }
}
535}