1use smol_str::SmolStr;
2use std::cmp::min;
3use std::str::FromStr;
4
5use rama_core::error::{ErrorContext, OpaqueError};
6use rama_utils::macros::str::eq_ignore_ascii_case;
7
8#[cfg(feature = "http")]
9use rama_http_types::{Method, Scheme};
10
11#[cfg(feature = "http")]
12use tracing::{trace, warn};
13
/// A network protocol, identified by its URI scheme.
///
/// The built-in protocols (`http`, `https`, `ws`, `wss`, `socks5`, `socks5h`) are
/// available as associated constants (e.g. [`Protocol::HTTP`]); any other
/// syntactically valid scheme is kept as a custom protocol.
///
/// NOTE(review): the derived `PartialEq`/`Eq`/`Hash`/`Ord` compare custom schemes
/// case-sensitively (`Custom` holds the scheme as given), whereas `PartialEq<str>`
/// for `Protocol` ignores ASCII case. For example, `Protocol::from_static("Foo")`
/// and `Protocol::from_static("foo")` are not equal to each other, yet both compare
/// equal to `"foo"`. Confirm this asymmetry is intended.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Protocol(ProtocolKind);
23
24impl Protocol {
25 #[cfg(feature = "http")]
26 pub fn maybe_from_uri_scheme_str_and_method(
27 s: Option<&Scheme>,
28 method: Option<&Method>,
29 ) -> Option<Self> {
30 s.map(|s| {
31 trace!("detected protocol from scheme");
32 let protocol: Protocol = s.into();
33 if method == Some(&Method::CONNECT) {
34 match protocol {
35 Protocol::HTTP => {
36 trace!("CONNECT request: upgrade HTTP => HTTPS");
37 Protocol::HTTPS
38 }
39 Protocol::HTTPS => Protocol::HTTPS,
40 Protocol::WS => {
41 trace!("CONNECT request: upgrade WS => WSS");
42 Protocol::WSS
43 }
44 Protocol::WSS => Protocol::WSS,
45 other => {
46 warn!(protocol = %other, "CONNECT request: unexpected protocol");
47 other
48 }
49 }
50 } else {
51 protocol
52 }
53 })
54 }
55}
56
/// Internal representation of a [`Protocol`].
///
/// `Ord` is derived, so the declaration order of the variants defines how
/// protocols sort. Reordering the variants changes that ordering.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
enum ProtocolKind {
    /// `http`; an empty scheme string also maps to this variant.
    Http,
    /// `https`
    Https,
    /// `ws`
    Ws,
    /// `wss`
    Wss,
    /// `socks5`
    Socks5,
    /// `socks5h`
    Socks5h,
    /// Any other scheme accepted by `validate_scheme_str`. It is stored exactly as
    /// given, with its original ASCII case preserved.
    Custom(SmolStr),
}
87
// Canonical (lowercase) scheme names of the built-in protocols. Parsing compares
// against these ASCII case-insensitively.

/// Scheme of [`Protocol::HTTP`].
const SCHEME_HTTP: &str = "http";
/// Scheme of [`Protocol::HTTPS`].
const SCHEME_HTTPS: &str = "https";
/// Scheme of [`Protocol::WS`].
const SCHEME_WS: &str = "ws";
/// Scheme of [`Protocol::WSS`].
const SCHEME_WSS: &str = "wss";
/// Scheme of [`Protocol::SOCKS5`].
const SCHEME_SOCKS5: &str = "socks5";
/// Scheme of [`Protocol::SOCKS5H`].
const SCHEME_SOCKS5H: &str = "socks5h";
94
95impl Protocol {
96 pub const HTTP: Self = Protocol(ProtocolKind::Http);
98
99 pub const HTTPS: Self = Protocol(ProtocolKind::Https);
101
102 pub const WS: Self = Protocol(ProtocolKind::Ws);
104
105 pub const WSS: Self = Protocol(ProtocolKind::Wss);
107
108 pub const SOCKS5: Self = Protocol(ProtocolKind::Socks5);
110
111 pub const SOCKS5H: Self = Protocol(ProtocolKind::Socks5h);
113
114 pub const fn from_static(s: &'static str) -> Self {
126 Protocol(if eq_ignore_ascii_case!(s, SCHEME_HTTPS) {
130 ProtocolKind::Https
131 } else if s.is_empty() || eq_ignore_ascii_case!(s, SCHEME_HTTP) {
132 ProtocolKind::Http
133 } else if eq_ignore_ascii_case!(s, SCHEME_SOCKS5) {
134 ProtocolKind::Socks5
135 } else if eq_ignore_ascii_case!(s, SCHEME_SOCKS5H) {
136 ProtocolKind::Socks5h
137 } else if eq_ignore_ascii_case!(s, SCHEME_WS) {
138 ProtocolKind::Ws
139 } else if eq_ignore_ascii_case!(s, SCHEME_WSS) {
140 ProtocolKind::Wss
141 } else if validate_scheme_str(s) {
142 ProtocolKind::Custom(SmolStr::new_static(s))
143 } else {
144 panic!("invalid static protocol str");
145 })
146 }
147
148 pub fn is_http(&self) -> bool {
150 match &self.0 {
151 ProtocolKind::Http | ProtocolKind::Https => true,
152 ProtocolKind::Ws
153 | ProtocolKind::Wss
154 | ProtocolKind::Socks5
155 | ProtocolKind::Socks5h
156 | ProtocolKind::Custom(_) => false,
157 }
158 }
159
160 pub fn is_ws(&self) -> bool {
162 match &self.0 {
163 ProtocolKind::Ws | ProtocolKind::Wss => true,
164 ProtocolKind::Http
165 | ProtocolKind::Https
166 | ProtocolKind::Socks5
167 | ProtocolKind::Socks5h
168 | ProtocolKind::Custom(_) => false,
169 }
170 }
171
172 pub fn is_socks5(&self) -> bool {
174 match &self.0 {
175 ProtocolKind::Socks5 => true,
176 ProtocolKind::Http
177 | ProtocolKind::Https
178 | ProtocolKind::Ws
179 | ProtocolKind::Wss
180 | ProtocolKind::Socks5h
181 | ProtocolKind::Custom(_) => false,
182 }
183 }
184
185 pub fn is_socks5h(&self) -> bool {
187 match &self.0 {
188 ProtocolKind::Socks5h => true,
189 ProtocolKind::Socks5
190 | ProtocolKind::Http
191 | ProtocolKind::Https
192 | ProtocolKind::Ws
193 | ProtocolKind::Wss
194 | ProtocolKind::Custom(_) => false,
195 }
196 }
197
198 pub fn is_secure(&self) -> bool {
200 match &self.0 {
201 ProtocolKind::Https | ProtocolKind::Wss => true,
202 ProtocolKind::Ws
203 | ProtocolKind::Http
204 | ProtocolKind::Socks5
205 | ProtocolKind::Socks5h
206 | ProtocolKind::Custom(_) => false,
207 }
208 }
209
210 pub fn default_port(&self) -> Option<u16> {
212 match &self.0 {
213 ProtocolKind::Https | ProtocolKind::Wss => Some(443),
214 ProtocolKind::Http | ProtocolKind::Ws => Some(80),
215 ProtocolKind::Socks5 | ProtocolKind::Socks5h => Some(1080),
216 ProtocolKind::Custom(_) => None,
217 }
218 }
219
220 pub fn as_str(&self) -> &str {
222 match &self.0 {
223 ProtocolKind::Http => "http",
224 ProtocolKind::Https => "https",
225 ProtocolKind::Ws => "ws",
226 ProtocolKind::Wss => "wss",
227 ProtocolKind::Socks5 => "socks5",
228 ProtocolKind::Socks5h => "socks5h",
229 ProtocolKind::Custom(s) => s.as_ref(),
230 }
231 }
232}
233
// Error returned by the string conversions into `Protocol` (`TryFrom<&str>`,
// `TryFrom<String>`, `TryFrom<&String>`, `FromStr`) when the input is neither a
// built-in protocol nor a valid custom scheme.
rama_utils::macros::error::static_str_error! {
    #[doc = "invalid protocol string"]
    pub struct InvalidProtocolStr;
}
238
239fn try_to_convert_str_to_non_custom_protocol(
240 s: &str,
241) -> Result<Option<Protocol>, InvalidProtocolStr> {
242 Ok(Some(Protocol(if eq_ignore_ascii_case!(s, SCHEME_HTTPS) {
243 ProtocolKind::Https
244 } else if s.is_empty() || eq_ignore_ascii_case!(s, SCHEME_HTTP) {
245 ProtocolKind::Http
246 } else if eq_ignore_ascii_case!(s, SCHEME_SOCKS5) {
247 ProtocolKind::Socks5
248 } else if eq_ignore_ascii_case!(s, SCHEME_SOCKS5H) {
249 ProtocolKind::Socks5h
250 } else if eq_ignore_ascii_case!(s, SCHEME_WS) {
251 ProtocolKind::Ws
252 } else if eq_ignore_ascii_case!(s, SCHEME_WSS) {
253 ProtocolKind::Wss
254 } else if validate_scheme_str(s) {
255 return Ok(None);
256 } else {
257 return Err(InvalidProtocolStr);
258 })))
259}
260
261impl TryFrom<&str> for Protocol {
262 type Error = InvalidProtocolStr;
263
264 fn try_from(s: &str) -> Result<Self, Self::Error> {
265 Ok(try_to_convert_str_to_non_custom_protocol(s)?
266 .unwrap_or_else(|| Protocol(ProtocolKind::Custom(SmolStr::new_inline(s)))))
267 }
268}
269
270impl TryFrom<String> for Protocol {
271 type Error = InvalidProtocolStr;
272
273 fn try_from(s: String) -> Result<Self, Self::Error> {
274 Ok(try_to_convert_str_to_non_custom_protocol(&s)?
275 .unwrap_or(Protocol(ProtocolKind::Custom(SmolStr::new(s)))))
276 }
277}
278
279impl TryFrom<&String> for Protocol {
280 type Error = InvalidProtocolStr;
281
282 fn try_from(s: &String) -> Result<Self, Self::Error> {
283 Ok(try_to_convert_str_to_non_custom_protocol(s)?
284 .unwrap_or_else(|| Protocol(ProtocolKind::Custom(SmolStr::new(s)))))
285 }
286}
287
288impl FromStr for Protocol {
289 type Err = InvalidProtocolStr;
290
291 fn from_str(s: &str) -> Result<Self, Self::Err> {
292 s.try_into()
293 }
294}
295
#[cfg(feature = "http")]
impl From<Scheme> for Protocol {
    /// Converts an owned http `Scheme` into a [`Protocol`].
    ///
    /// NOTE(review): the `expect` relies on the http crate never accepting a
    /// scheme that `validate_scheme_str` rejects. Re-confirm this when upgrading
    /// `http`.
    #[inline]
    fn from(s: Scheme) -> Self {
        Self::try_from(s.as_str()).expect("http crate Scheme is pre-validated by promise")
    }
}
305
#[cfg(feature = "http")]
impl From<&Scheme> for Protocol {
    /// Converts a borrowed http `Scheme` into a [`Protocol`].
    ///
    /// NOTE(review): like the owned conversion, this panics if the http crate
    /// ever accepts a scheme that `validate_scheme_str` rejects.
    fn from(s: &Scheme) -> Self {
        let raw: &str = s.as_str();
        Protocol::try_from(raw).expect("http crate Scheme is pre-validated by promise")
    }
}
314
315impl PartialEq<str> for Protocol {
316 fn eq(&self, other: &str) -> bool {
317 match &self.0 {
318 ProtocolKind::Https => other.eq_ignore_ascii_case(SCHEME_HTTPS),
319 ProtocolKind::Http => other.eq_ignore_ascii_case(SCHEME_HTTP) || other.is_empty(),
320 ProtocolKind::Socks5 => other.eq_ignore_ascii_case(SCHEME_SOCKS5),
321 ProtocolKind::Socks5h => other.eq_ignore_ascii_case(SCHEME_SOCKS5H),
322 ProtocolKind::Ws => other.eq_ignore_ascii_case("ws"),
323 ProtocolKind::Wss => other.eq_ignore_ascii_case("wss"),
324 ProtocolKind::Custom(s) => other.eq_ignore_ascii_case(s),
325 }
326 }
327}
328
329impl PartialEq<String> for Protocol {
330 fn eq(&self, other: &String) -> bool {
331 self == other.as_str()
332 }
333}
334
335impl PartialEq<&str> for Protocol {
336 fn eq(&self, other: &&str) -> bool {
337 self == *other
338 }
339}
340
341impl PartialEq<Protocol> for str {
342 fn eq(&self, other: &Protocol) -> bool {
343 other == self
344 }
345}
346
347impl PartialEq<Protocol> for String {
348 fn eq(&self, other: &Protocol) -> bool {
349 other == self.as_str()
350 }
351}
352
353impl PartialEq<Protocol> for &str {
354 fn eq(&self, other: &Protocol) -> bool {
355 other == *self
356 }
357}
358
359impl std::fmt::Display for Protocol {
360 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
361 self.as_str().fmt(f)
362 }
363}
364
365pub(crate) fn try_to_extract_protocol_from_uri_scheme(
366 s: &[u8],
367) -> Result<(Option<Protocol>, usize), OpaqueError> {
368 if s.is_empty() {
369 return Err(OpaqueError::from_display("empty uri contains no scheme"));
370 }
371
372 for i in 0..min(s.len(), 512) {
373 let b = s[i];
374
375 if b == b':' {
376 if s.len() < i + 3 {
378 break;
379 }
380
381 if &s[i + 1..i + 3] != b"//" {
383 break;
384 }
385
386 let str =
387 std::str::from_utf8(&s[..i]).context("interpret scheme bytes as utf-8 str")?;
388 let protocol = str
389 .try_into()
390 .context("parse scheme utf-8 str as protocol")?;
391 return Ok((Some(protocol), i + 3));
392 }
393 }
394
395 Ok((None, 0))
396}
397
/// Returns `true` if `s` is usable as a custom URI scheme.
///
/// This is a `const fn` so that `Protocol::from_static` can call it.
#[inline]
const fn validate_scheme_str(s: &str) -> bool {
    let bytes = s.as_bytes();
    validate_scheme_slice(bytes)
}

/// Returns `true` if `s` is non-empty, at most [`MAX_SCHEME_LEN`] bytes long, and
/// made up only of bytes that [`SCHEME_CHARS`] marks as valid.
const fn validate_scheme_slice(s: &[u8]) -> bool {
    if s.is_empty() || s.len() > MAX_SCHEME_LEN {
        return false;
    }

    // Walk the slice head-first; `for` loops are not available in `const fn`.
    let mut remaining = s;
    while let [head, tail @ ..] = remaining {
        if SCHEME_CHARS[*head as usize] == 0 {
            return false;
        }
        remaining = tail;
    }
    true
}

/// Longest scheme, in bytes, accepted by [`validate_scheme_slice`].
const MAX_SCHEME_LEN: usize = 64;

/// Lookup table for scheme bytes.
///
/// Entry `b` holds `b` itself when that byte may appear in a scheme (ASCII
/// letters, digits, `+`, `-`, `.`), and `0` otherwise. `:` is not included.
const SCHEME_CHARS: [u8; 256] = {
    let mut table = [0u8; 256];
    let mut index = 0;
    while index < 256 {
        let byte = index as u8;
        if byte.is_ascii_alphanumeric() || matches!(byte, b'+' | b'-' | b'.') {
            table[index] = byte;
        }
        index += 1;
    }
    table
};
460
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_from_str() {
        assert_eq!("http".parse(), Ok(Protocol::HTTP));
        assert_eq!("".parse(), Ok(Protocol::HTTP));
        assert_eq!("https".parse(), Ok(Protocol::HTTPS));
        assert_eq!("ws".parse(), Ok(Protocol::WS));
        assert_eq!("wss".parse(), Ok(Protocol::WSS));
        assert_eq!("socks5".parse(), Ok(Protocol::SOCKS5));
        assert_eq!("socks5h".parse(), Ok(Protocol::SOCKS5H));
        assert_eq!("custom".parse(), Ok(Protocol::from_static("custom")));
    }

    #[test]
    fn test_from_str_is_ascii_case_insensitive() {
        assert_eq!("HTTP".parse(), Ok(Protocol::HTTP));
        assert_eq!("HttpS".parse(), Ok(Protocol::HTTPS));
        assert_eq!("Ws".parse(), Ok(Protocol::WS));
        assert_eq!("WSS".parse(), Ok(Protocol::WSS));
        assert_eq!("SOCKS5".parse(), Ok(Protocol::SOCKS5));
        assert_eq!("Socks5H".parse(), Ok(Protocol::SOCKS5H));
    }

    #[test]
    fn test_from_str_rejects_invalid_schemes() {
        for s in [" http", "ht:tp", "http/", "h\u{e9}llo"] {
            assert!(s.parse::<Protocol>().is_err(), "case: {s}");
        }
        assert!("a".repeat(65).parse::<Protocol>().is_err());
    }

    #[cfg(feature = "http")]
    #[test]
    fn test_from_http_scheme() {
        for s in [
            "http", "https", "ws", "wss", "socks5", "socks5h", "", "custom",
        ] {
            let uri = rama_http_types::Uri::from_str(format!("{s}://example.com").as_str()).unwrap();
            assert_eq!(Protocol::from(uri.scheme().unwrap()), s);
        }
    }

    #[cfg(feature = "http")]
    #[test]
    fn test_maybe_from_uri_scheme_str_and_method() {
        let http: Scheme = "http".parse().unwrap();
        let https: Scheme = "https".parse().unwrap();
        let ws: Scheme = "ws".parse().unwrap();

        assert_eq!(
            Protocol::maybe_from_uri_scheme_str_and_method(None, Some(&Method::CONNECT)),
            None
        );
        assert_eq!(
            Protocol::maybe_from_uri_scheme_str_and_method(Some(&http), None),
            Some(Protocol::HTTP)
        );
        assert_eq!(
            Protocol::maybe_from_uri_scheme_str_and_method(Some(&http), Some(&Method::GET)),
            Some(Protocol::HTTP)
        );
        assert_eq!(
            Protocol::maybe_from_uri_scheme_str_and_method(Some(&http), Some(&Method::CONNECT)),
            Some(Protocol::HTTPS)
        );
        assert_eq!(
            Protocol::maybe_from_uri_scheme_str_and_method(Some(&https), Some(&Method::CONNECT)),
            Some(Protocol::HTTPS)
        );
        assert_eq!(
            Protocol::maybe_from_uri_scheme_str_and_method(Some(&ws), Some(&Method::CONNECT)),
            Some(Protocol::WSS)
        );
    }

    #[test]
    fn test_scheme_is_secure() {
        assert!(!Protocol::HTTP.is_secure());
        assert!(Protocol::HTTPS.is_secure());
        assert!(!Protocol::SOCKS5.is_secure());
        assert!(!Protocol::SOCKS5H.is_secure());
        assert!(!Protocol::WS.is_secure());
        assert!(Protocol::WSS.is_secure());
        assert!(!Protocol::from_static("custom").is_secure());
    }

    #[test]
    fn test_protocol_family_predicates() {
        assert!(Protocol::HTTP.is_http() && Protocol::HTTPS.is_http());
        assert!(!Protocol::WS.is_http());
        assert!(Protocol::WS.is_ws() && Protocol::WSS.is_ws());
        assert!(!Protocol::HTTP.is_ws());
        assert!(Protocol::SOCKS5.is_socks5() && !Protocol::SOCKS5.is_socks5h());
        assert!(Protocol::SOCKS5H.is_socks5h() && !Protocol::SOCKS5H.is_socks5());

        let custom = Protocol::from_static("custom");
        assert!(!custom.is_http() && !custom.is_ws());
        assert!(!custom.is_socks5() && !custom.is_socks5h());
    }

    #[test]
    fn test_default_port() {
        assert_eq!(Protocol::HTTP.default_port(), Some(80));
        assert_eq!(Protocol::HTTPS.default_port(), Some(443));
        assert_eq!(Protocol::WS.default_port(), Some(80));
        assert_eq!(Protocol::WSS.default_port(), Some(443));
        assert_eq!(Protocol::SOCKS5.default_port(), Some(1080));
        assert_eq!(Protocol::SOCKS5H.default_port(), Some(1080));
        assert_eq!(Protocol::from_static("custom").default_port(), None);
    }

    #[test]
    fn test_display_matches_as_str() {
        for p in [
            Protocol::HTTP,
            Protocol::HTTPS,
            Protocol::WS,
            Protocol::WSS,
            Protocol::SOCKS5,
            Protocol::SOCKS5H,
            Protocol::from_static("custom"),
        ] {
            assert_eq!(p.to_string(), p.as_str());
        }
        assert_eq!(Protocol::SOCKS5H.to_string(), "socks5h");
        assert_eq!(format!("{:>6}", Protocol::WS), "    ws");
    }

    #[test]
    fn test_eq_str_is_ascii_case_insensitive() {
        assert_eq!(Protocol::HTTP, "HTTP");
        assert_eq!(Protocol::HTTP, "");
        assert_eq!("wss", Protocol::WSS);
        assert_eq!(Protocol::from_static("custom"), "CUSTOM");
        assert_eq!(Protocol::SOCKS5, String::from("Socks5"));
        assert_ne!(Protocol::HTTPS, "http");
        assert_ne!(Protocol::HTTPS, "");
    }

    #[test]
    fn test_try_to_extract_protocol_from_uri_scheme() {
        for (s, expected) in [
            ("", None),
            ("http://example.com", Some((Some(Protocol::HTTP), 7))),
            ("https://example.com", Some((Some(Protocol::HTTPS), 8))),
            ("ws://example.com", Some((Some(Protocol::WS), 5))),
            ("wss://example.com", Some((Some(Protocol::WSS), 6))),
            ("socks5://example.com", Some((Some(Protocol::SOCKS5), 9))),
            ("socks5h://example.com", Some((Some(Protocol::SOCKS5H), 10))),
            (
                "custom://example.com",
                Some((Some(Protocol::from_static("custom")), 9)),
            ),
            (" http://example.com", None),
            ("example.com", Some((None, 0))),
            ("127.0.0.1", Some((None, 0))),
            ("127.0.0.1:8080", Some((None, 0))),
            (
                "longlonglongwaytoolongforsomethingusefulorvaliddontyouthinkmydearreader://example.com",
                None,
            ),
        ] {
            let result = try_to_extract_protocol_from_uri_scheme(s.as_bytes());
            match expected {
                Some(t) => match result {
                    Err(err) => panic!("unexpected err: {err} (case: {s})"),
                    Ok(p) => assert_eq!(t, p, "case: {s}"),
                },
                None => assert!(result.is_err(), "case: {s}, result: {result:?}"),
            }
        }
    }
}