rama_net/forwarded/
proto.rs1use crate::Protocol;
2use rama_utils::macros::{error::static_str_error, str::eq_ignore_ascii_case};
3use std::str::FromStr;
4
5#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
6pub struct ForwardedProtocol(ProtocolKind);
16
/// Internal discriminant behind [`ForwardedProtocol`].
///
/// The declaration order is significant: the derived `PartialOrd`/`Ord`
/// rank `Http` before `Https`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
enum ProtocolKind {
    /// Plain, unencrypted HTTP.
    Http,
    /// HTTP over TLS.
    Https,
}
24
/// Canonical lowercase scheme name for plain HTTP.
const HTTP_STR: &str = "http";
/// Canonical lowercase scheme name for HTTPS.
const HTTPS_STR: &str = "https";
27
28impl ForwardedProtocol {
29 pub const HTTP: ForwardedProtocol = ForwardedProtocol(ProtocolKind::Http);
31
32 pub const HTTPS: ForwardedProtocol = ForwardedProtocol(ProtocolKind::Https);
34
35 pub fn is_http(&self) -> bool {
37 match &self.0 {
38 ProtocolKind::Http | ProtocolKind::Https => true,
39 }
40 }
41
42 pub fn is_secure(&self) -> bool {
44 match self.0 {
45 ProtocolKind::Https => true,
46 ProtocolKind::Http => false,
47 }
48 }
49
50 pub fn as_scheme(&self) -> &str {
52 match &self.0 {
53 ProtocolKind::Https => HTTPS_STR,
54 ProtocolKind::Http => HTTP_STR,
55 }
56 }
57
58 #[inline]
59 pub fn into_protocol(self) -> Protocol {
61 self.into()
62 }
63
64 pub fn as_str(&self) -> &str {
66 match &self.0 {
67 ProtocolKind::Https => HTTPS_STR,
68 ProtocolKind::Http => HTTP_STR,
69 }
70 }
71}
72
73impl From<ForwardedProtocol> for Protocol {
74 fn from(p: ForwardedProtocol) -> Self {
75 match p.0 {
76 ProtocolKind::Https => Protocol::HTTPS,
77 ProtocolKind::Http => Protocol::HTTP,
78 }
79 }
80}
81
82static_str_error! {
83 #[doc = "unknown protocol"]
84 pub struct UnknownProtocol;
85}
86
87impl TryFrom<Protocol> for ForwardedProtocol {
88 type Error = UnknownProtocol;
89
90 fn try_from(p: Protocol) -> Result<Self, Self::Error> {
91 if p.is_http() {
92 if p.is_secure() {
93 Ok(ForwardedProtocol(ProtocolKind::Https))
94 } else {
95 Ok(ForwardedProtocol(ProtocolKind::Http))
96 }
97 } else {
98 Err(UnknownProtocol)
99 }
100 }
101}
102
103impl TryFrom<&Protocol> for ForwardedProtocol {
104 type Error = UnknownProtocol;
105
106 fn try_from(p: &Protocol) -> Result<Self, Self::Error> {
107 if p.is_http() {
108 if p.is_secure() {
109 Ok(ForwardedProtocol(ProtocolKind::Https))
110 } else {
111 Ok(ForwardedProtocol(ProtocolKind::Http))
112 }
113 } else {
114 Err(UnknownProtocol)
115 }
116 }
117}
118
119static_str_error! {
120 #[doc = "invalid forwarded protocol string"]
121 pub struct InvalidProtocolStr;
122}
123
124impl TryFrom<&str> for ForwardedProtocol {
125 type Error = InvalidProtocolStr;
126
127 fn try_from(s: &str) -> Result<Self, Self::Error> {
128 if eq_ignore_ascii_case!(s, HTTP_STR) {
129 Ok(ForwardedProtocol(ProtocolKind::Http))
130 } else if eq_ignore_ascii_case!(s, HTTPS_STR) {
131 Ok(ForwardedProtocol(ProtocolKind::Https))
132 } else {
133 Err(InvalidProtocolStr)
134 }
135 }
136}
137
138impl TryFrom<String> for ForwardedProtocol {
139 type Error = InvalidProtocolStr;
140
141 fn try_from(s: String) -> Result<Self, Self::Error> {
142 s.as_str().try_into()
143 }
144}
145
146impl TryFrom<&String> for ForwardedProtocol {
147 type Error = InvalidProtocolStr;
148
149 fn try_from(s: &String) -> Result<Self, Self::Error> {
150 s.as_str().try_into()
151 }
152}
153
154impl FromStr for ForwardedProtocol {
155 type Err = InvalidProtocolStr;
156
157 fn from_str(s: &str) -> Result<Self, Self::Err> {
158 s.try_into()
159 }
160}
161
162impl PartialEq<str> for ForwardedProtocol {
163 fn eq(&self, other: &str) -> bool {
164 match &self.0 {
165 ProtocolKind::Https => other.eq_ignore_ascii_case(HTTPS_STR),
166 ProtocolKind::Http => other.eq_ignore_ascii_case(HTTP_STR) || other.is_empty(),
167 }
168 }
169}
170
171impl PartialEq<String> for ForwardedProtocol {
172 fn eq(&self, other: &String) -> bool {
173 self == other.as_str()
174 }
175}
176
177impl PartialEq<&str> for ForwardedProtocol {
178 fn eq(&self, other: &&str) -> bool {
179 self == *other
180 }
181}
182
183impl PartialEq<ForwardedProtocol> for str {
184 fn eq(&self, other: &ForwardedProtocol) -> bool {
185 other == self
186 }
187}
188
189impl PartialEq<ForwardedProtocol> for String {
190 fn eq(&self, other: &ForwardedProtocol) -> bool {
191 other == self.as_str()
192 }
193}
194
195impl PartialEq<ForwardedProtocol> for &str {
196 fn eq(&self, other: &ForwardedProtocol) -> bool {
197 other == *self
198 }
199}
200
201impl std::fmt::Display for ForwardedProtocol {
202 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
203 self.as_scheme().fmt(f)
204 }
205}
206
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_protocol_from_str() {
        assert_eq!("http".parse(), Ok(ForwardedProtocol::HTTP));
        assert_eq!("https".parse(), Ok(ForwardedProtocol::HTTPS));
    }

    #[test]
    fn test_protocol_from_str_ignores_ascii_case() {
        assert_eq!("HTTP".parse(), Ok(ForwardedProtocol::HTTP));
        assert_eq!("HtTpS".parse(), Ok(ForwardedProtocol::HTTPS));
    }

    #[test]
    fn test_protocol_from_str_rejects_unknown() {
        assert!("".parse::<ForwardedProtocol>().is_err());
        assert!("ftp".parse::<ForwardedProtocol>().is_err());
        assert!("https ".parse::<ForwardedProtocol>().is_err());
    }

    #[test]
    fn test_protocol_secure() {
        assert!(!ForwardedProtocol::HTTP.is_secure());
        assert!(ForwardedProtocol::HTTPS.is_secure());
    }

    #[test]
    fn test_protocol_display_and_str() {
        assert_eq!(ForwardedProtocol::HTTP.to_string(), "http");
        assert_eq!(ForwardedProtocol::HTTPS.to_string(), "https");
        assert_eq!(ForwardedProtocol::HTTP.as_str(), "http");
        assert_eq!(ForwardedProtocol::HTTPS.as_scheme(), "https");
    }

    #[test]
    fn test_protocol_eq_str_ignores_ascii_case() {
        assert_eq!(ForwardedProtocol::HTTPS, "HTTPS");
        assert_eq!("Http", ForwardedProtocol::HTTP);
        assert_ne!(ForwardedProtocol::HTTP, "https");
        assert_ne!(ForwardedProtocol::HTTPS, String::from("http"));
    }

    #[test]
    fn test_protocol_roundtrip() {
        assert!(!ForwardedProtocol::HTTP.into_protocol().is_secure());
        assert!(ForwardedProtocol::HTTPS.into_protocol().is_secure());

        for proto in [ForwardedProtocol::HTTP, ForwardedProtocol::HTTPS] {
            let protocol = proto.clone().into_protocol();
            assert_eq!(
                ForwardedProtocol::try_from(&protocol).ok(),
                Some(proto.clone())
            );
            assert_eq!(ForwardedProtocol::try_from(protocol).ok(), Some(proto));
        }
    }
}