rama_net/forwarded/
proto.rs

1use crate::Protocol;
2use rama_utils::macros::{error::static_str_error, str::eq_ignore_ascii_case};
3use std::str::FromStr;
4
5#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
6/// Protocols that were forwarded.
7///
8/// These are a subset of [`Protocol`].
9///
10/// Please [file an issue or open a PR][repo] if you need support for more protocols.
11/// When doing so please provide sufficient motivation and ensure
12/// it has no unintended consequences.
13///
14/// [repo]: https://github.com/plabayo/rama
15pub struct ForwardedProtocol(ProtocolKind);
16
/// The concrete protocol wrapped by a `ForwardedProtocol`.
///
/// Variant order is significant: the derived `PartialOrd`/`Ord` place
/// `Http` before `Https`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
enum ProtocolKind {
    /// Plain `http`.
    Http,
    /// Secure `https`.
    Https,
}
24
/// Canonical lowercase string form of the `http` protocol.
const HTTP_STR: &str = "http";
/// Canonical lowercase string form of the `https` protocol.
const HTTPS_STR: &str = "https";
27
28impl ForwardedProtocol {
29    /// `HTTP` protocol.
30    pub const HTTP: ForwardedProtocol = ForwardedProtocol(ProtocolKind::Http);
31
32    /// `HTTPS` protocol.
33    pub const HTTPS: ForwardedProtocol = ForwardedProtocol(ProtocolKind::Https);
34
35    /// Returns `true` if this protocol is http(s).
36    pub fn is_http(&self) -> bool {
37        match &self.0 {
38            ProtocolKind::Http | ProtocolKind::Https => true,
39        }
40    }
41
42    /// Returns `true` if this protocol is "secure" by itself.
43    pub fn is_secure(&self) -> bool {
44        match self.0 {
45            ProtocolKind::Https => true,
46            ProtocolKind::Http => false,
47        }
48    }
49
50    /// Returns the scheme str for this protocol.
51    pub fn as_scheme(&self) -> &str {
52        match &self.0 {
53            ProtocolKind::Https => HTTPS_STR,
54            ProtocolKind::Http => HTTP_STR,
55        }
56    }
57
58    #[inline]
59    /// Consumes the protocol and returns a [`Protocol`].
60    pub fn into_protocol(self) -> Protocol {
61        self.into()
62    }
63
64    /// Returns the [`ForwardedProtocol`] as a string.
65    pub fn as_str(&self) -> &str {
66        match &self.0 {
67            ProtocolKind::Https => HTTPS_STR,
68            ProtocolKind::Http => HTTP_STR,
69        }
70    }
71}
72
73impl From<ForwardedProtocol> for Protocol {
74    fn from(p: ForwardedProtocol) -> Self {
75        match p.0 {
76            ProtocolKind::Https => Protocol::HTTPS,
77            ProtocolKind::Http => Protocol::HTTP,
78        }
79    }
80}
81
82static_str_error! {
83    #[doc = "unknown protocol"]
84    pub struct UnknownProtocol;
85}
86
87impl TryFrom<Protocol> for ForwardedProtocol {
88    type Error = UnknownProtocol;
89
90    fn try_from(p: Protocol) -> Result<Self, Self::Error> {
91        if p.is_http() {
92            if p.is_secure() {
93                Ok(ForwardedProtocol(ProtocolKind::Https))
94            } else {
95                Ok(ForwardedProtocol(ProtocolKind::Http))
96            }
97        } else {
98            Err(UnknownProtocol)
99        }
100    }
101}
102
103impl TryFrom<&Protocol> for ForwardedProtocol {
104    type Error = UnknownProtocol;
105
106    fn try_from(p: &Protocol) -> Result<Self, Self::Error> {
107        if p.is_http() {
108            if p.is_secure() {
109                Ok(ForwardedProtocol(ProtocolKind::Https))
110            } else {
111                Ok(ForwardedProtocol(ProtocolKind::Http))
112            }
113        } else {
114            Err(UnknownProtocol)
115        }
116    }
117}
118
119static_str_error! {
120    #[doc = "invalid forwarded protocol string"]
121    pub struct InvalidProtocolStr;
122}
123
124impl TryFrom<&str> for ForwardedProtocol {
125    type Error = InvalidProtocolStr;
126
127    fn try_from(s: &str) -> Result<Self, Self::Error> {
128        if eq_ignore_ascii_case!(s, HTTP_STR) {
129            Ok(ForwardedProtocol(ProtocolKind::Http))
130        } else if eq_ignore_ascii_case!(s, HTTPS_STR) {
131            Ok(ForwardedProtocol(ProtocolKind::Https))
132        } else {
133            Err(InvalidProtocolStr)
134        }
135    }
136}
137
138impl TryFrom<String> for ForwardedProtocol {
139    type Error = InvalidProtocolStr;
140
141    fn try_from(s: String) -> Result<Self, Self::Error> {
142        s.as_str().try_into()
143    }
144}
145
146impl TryFrom<&String> for ForwardedProtocol {
147    type Error = InvalidProtocolStr;
148
149    fn try_from(s: &String) -> Result<Self, Self::Error> {
150        s.as_str().try_into()
151    }
152}
153
154impl FromStr for ForwardedProtocol {
155    type Err = InvalidProtocolStr;
156
157    fn from_str(s: &str) -> Result<Self, Self::Err> {
158        s.try_into()
159    }
160}
161
162impl PartialEq<str> for ForwardedProtocol {
163    fn eq(&self, other: &str) -> bool {
164        match &self.0 {
165            ProtocolKind::Https => other.eq_ignore_ascii_case(HTTPS_STR),
166            ProtocolKind::Http => other.eq_ignore_ascii_case(HTTP_STR) || other.is_empty(),
167        }
168    }
169}
170
171impl PartialEq<String> for ForwardedProtocol {
172    fn eq(&self, other: &String) -> bool {
173        self == other.as_str()
174    }
175}
176
177impl PartialEq<&str> for ForwardedProtocol {
178    fn eq(&self, other: &&str) -> bool {
179        self == *other
180    }
181}
182
183impl PartialEq<ForwardedProtocol> for str {
184    fn eq(&self, other: &ForwardedProtocol) -> bool {
185        other == self
186    }
187}
188
189impl PartialEq<ForwardedProtocol> for String {
190    fn eq(&self, other: &ForwardedProtocol) -> bool {
191        other == self.as_str()
192    }
193}
194
195impl PartialEq<ForwardedProtocol> for &str {
196    fn eq(&self, other: &ForwardedProtocol) -> bool {
197        other == *self
198    }
199}
200
201impl std::fmt::Display for ForwardedProtocol {
202    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
203        self.as_scheme().fmt(f)
204    }
205}
206
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_protocol_from_str() {
        assert_eq!("http".parse(), Ok(ForwardedProtocol::HTTP));
        assert_eq!("https".parse(), Ok(ForwardedProtocol::HTTPS));
    }

    #[test]
    fn test_protocol_from_str_ignores_ascii_case() {
        assert_eq!("HTTP".parse(), Ok(ForwardedProtocol::HTTP));
        assert_eq!("HtTpS".parse(), Ok(ForwardedProtocol::HTTPS));
        assert_eq!(
            ForwardedProtocol::try_from(String::from("HTTPS")),
            Ok(ForwardedProtocol::HTTPS)
        );
        let owned = String::from("Http");
        assert_eq!(ForwardedProtocol::try_from(&owned), Ok(ForwardedProtocol::HTTP));
    }

    #[test]
    fn test_protocol_from_str_rejects_unknown() {
        for input in ["", "ftp", "htt", "httpss", " http", "http "] {
            assert!(
                input.parse::<ForwardedProtocol>().is_err(),
                "{input:?} should not parse as a forwarded protocol"
            );
        }
    }

    #[test]
    fn test_protocol_secure() {
        assert!(!ForwardedProtocol::HTTP.is_secure());
        assert!(ForwardedProtocol::HTTPS.is_secure());
    }

    #[test]
    fn test_protocol_is_http() {
        assert!(ForwardedProtocol::HTTP.is_http());
        assert!(ForwardedProtocol::HTTPS.is_http());
    }

    #[test]
    fn test_protocol_as_str_and_display() {
        assert_eq!(ForwardedProtocol::HTTP.as_str(), "http");
        assert_eq!(ForwardedProtocol::HTTPS.as_scheme(), "https");
        assert_eq!(ForwardedProtocol::HTTPS.to_string(), "https");
        // Display must honour formatter flags like a plain `str`.
        assert_eq!(format!("{:>6}", ForwardedProtocol::HTTP), "  http");
    }

    #[test]
    fn test_protocol_eq_str_ignores_ascii_case() {
        assert!(ForwardedProtocol::HTTP == "HTTP");
        assert!(ForwardedProtocol::HTTPS == "https");
        assert!(ForwardedProtocol::HTTP != "https");
        assert!(ForwardedProtocol::HTTPS != "http");
        assert!("Https" == ForwardedProtocol::HTTPS);

        let lower = String::from("http");
        let upper = String::from("HTTPS");
        assert!(lower == ForwardedProtocol::HTTP);
        assert!(ForwardedProtocol::HTTPS == upper);
    }
}