monoio_http_client/client/
key.rs1use std::{convert::Infallible, hash::Hash, net::ToSocketAddrs};
2
3use http::{Uri, Version};
4use service_async::{Param, ParamMut, ParamRef};
5use smol_str::SmolStr;
6use thiserror::Error as ThisError;
7
8#[cfg(feature = "native-tls")]
9#[derive(Debug, Clone, PartialEq, Eq)]
10pub struct ServerName(pub SmolStr);
11#[cfg(not(feature = "native-tls"))]
12pub use rustls::ServerName;
13
14use super::unified::UnifiedTransportAddr;
15
/// Converts any value that can become a [`SmolStr`] into a `native-tls` [`ServerName`].
/// No validation is performed; the string is stored as given.
#[cfg(feature = "native-tls")]
impl<T> From<T> for ServerName
where
    T: Into<SmolStr>,
{
    fn from(name: T) -> Self {
        let inner: SmolStr = name.into();
        ServerName(inner)
    }
}
22
23pub struct Key {
24 pub host: SmolStr,
25 pub port: u16,
26 pub version: Version,
27
28 pub server_name: Option<ServerName>,
29}
30
31pub trait HttpVersion {
32 fn get_version(&self) -> Version;
33}
34
35impl HttpVersion for Key {
36 fn get_version(&self) -> Version {
37 self.version
38 }
39}
40
41impl Clone for Key {
42 fn clone(&self) -> Self {
43 Self {
44 host: self.host.clone(),
45 port: self.port,
46 server_name: self.server_name.clone(),
47 version: self.version,
48 }
49 }
50}
51
52impl std::fmt::Debug for Key {
53 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54 write!(f, "{}:{}:{:?}", self.host, self.port, self.version)
55 }
56}
57
58impl std::fmt::Display for Key {
59 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
60 write!(f, "{}:{}:{:?}", self.host, self.port, self.version)
61 }
62}
63
64impl PartialEq for Key {
65 fn eq(&self, other: &Self) -> bool {
66 self.host == other.host && self.port == other.port && self.version == other.version
67 }
68}
69
70impl Eq for Key {}
71
72impl Hash for Key {
73 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
74 self.host.hash(state);
75 self.port.hash(state);
76 self.version.hash(state);
77 }
78}
79
80impl ToSocketAddrs for Key {
81 type Iter = <(&'static str, u16) as ToSocketAddrs>::Iter;
82
83 fn to_socket_addrs(&self) -> std::io::Result<Self::Iter> {
84 (self.host.as_str(), self.port).to_socket_addrs()
85 }
86}
87
88impl Param<Option<ServerName>> for Key {
89 fn param(&self) -> Option<ServerName> {
90 self.server_name.clone()
91 }
92}
93
94impl ParamRef<Option<ServerName>> for Key {
95 fn param_ref(&self) -> &Option<ServerName> {
96 &self.server_name
97 }
98}
99
100impl ParamMut<Option<ServerName>> for Key {
101 fn param_mut(&mut self) -> &mut Option<ServerName> {
102 &mut self.server_name
103 }
104}
105
106impl Param<UnifiedTransportAddr> for Key {
107 fn param(&self) -> UnifiedTransportAddr {
108 if let Some(sn) = self.server_name.clone() {
109 UnifiedTransportAddr::TcpTls(self.host.clone(), self.port, sn)
110 } else {
111 UnifiedTransportAddr::Tcp(self.host.clone(), self.port)
112 }
113 }
114}
115
116#[derive(ThisError, Debug)]
117pub enum FromUriError {
118 #[error("Invalid dns name")]
119 InvalidDnsName(#[from] rustls::client::InvalidDnsNameError),
120 #[error("Scheme not supported")]
121 UnsupportScheme,
122 #[error("Missing authority in uri")]
123 NoAuthority,
124}
125
126impl From<Infallible> for FromUriError {
127 fn from(_: Infallible) -> Self {
128 unsafe { std::hint::unreachable_unchecked() }
129 }
130}
131
132impl TryFrom<&Uri> for Key {
133 type Error = FromUriError;
134
135 fn try_from(uri: &Uri) -> Result<Self, Self::Error> {
136 let (tls, default_port) = match uri.scheme() {
137 Some(scheme) if scheme == &http::uri::Scheme::HTTP => (false, 80),
138 Some(scheme) if scheme == &http::uri::Scheme::HTTPS => (true, 443),
139 _ => (false, 0),
140 };
141 let port = uri.port_u16().unwrap_or(default_port);
142
143 let host = match uri.host() {
144 Some(a) => {
145 if a.starts_with('[') && a.ends_with(']') {
146 &a[1..a.len() - 1]
147 } else {
148 a
149 }
150 }
151 None => return Err(FromUriError::NoAuthority),
152 };
153 let sni = if tls { Some(host) } else { None };
154 (host, sni, port).try_into().map_err(Into::into)
155 }
156}
157
158impl TryFrom<Uri> for Key {
159 type Error = FromUriError;
160
161 fn try_from(value: Uri) -> Result<Self, Self::Error> {
162 Self::try_from(&value)
163 }
164}
165
166impl TryFrom<(&str, Option<&str>, u16)> for Key {
168 #[cfg(not(feature = "native-tls"))]
169 type Error = rustls::client::InvalidDnsNameError;
170 #[cfg(feature = "native-tls")]
171 type Error = std::convert::Infallible;
172
173 fn try_from((host, server_name, port): (&str, Option<&str>, u16)) -> Result<Self, Self::Error> {
174 let server_name = match server_name {
175 Some(s) => s,
176 None => {
177 return Ok(Self {
178 host: host.into(),
179 port,
180 version: http::version::Version::HTTP_11,
181 server_name: None,
182 })
183 }
184 };
185
186 #[cfg(not(feature = "native-tls"))]
187 let server_name = Some(ServerName::try_from(server_name)?);
188 #[cfg(feature = "native-tls")]
189 let server_name = Some(ServerName(server_name.into()));
190
191 Ok(Self {
192 host: host.into(),
193 port,
194 version: http::version::Version::HTTP_11,
195 server_name,
196 })
197 }
198}
199
#[cfg(test)]
mod tests {
    use super::*;

    /// Expected `server_name` for a TLS key, built the same way the active backend does.
    ///
    /// These cfg gates mirror the ones on `ServerName` itself (`native-tls` wins whenever it
    /// is enabled). The assertions therefore always check the type actually compiled in,
    /// including builds where both `rustls` and `native-tls` are on.
    #[cfg(not(feature = "native-tls"))]
    fn expected_sni(name: &str) -> Option<ServerName> {
        Some(ServerName::try_from(name).expect("test server name must be a valid DNS name"))
    }

    #[cfg(feature = "native-tls")]
    fn expected_sni(name: &str) -> Option<ServerName> {
        Some(ServerName::from(name))
    }

    #[test]
    fn key_default_port() {
        // Tuple construction keeps the given port and server name verbatim.
        let key: Key = ("bytedance.com", Some("bytedance.com"), 80)
            .try_into()
            .expect("unable to convert to Key");
        assert_eq!(key.port, 80);
        assert_eq!(key.host, "bytedance.com");
        assert_eq!(key.server_name, expected_sni("bytedance.com"));
    }

    #[test]
    fn key_specify_port() {
        let uri = Uri::try_from("https://bytedance.com:12345").unwrap();
        let key: Key = uri.try_into().expect("unable to convert to Key");
        assert_eq!(key.port, 12345);
        assert_eq!(key.host, "bytedance.com");
        assert_eq!(key.server_name, expected_sni("bytedance.com"));
    }

    #[test]
    fn key_ip_http() {
        let uri = Uri::try_from("http://1.1.1.1:443").unwrap();
        let key: Key = uri.try_into().expect("unable to convert to Key");
        assert_eq!(key.port, 443);
        assert_eq!(key.host, "1.1.1.1");
        assert_eq!(key.server_name, None);
    }

    #[test]
    fn key_uri() {
        let uri = Uri::try_from("https://bytedance.com").unwrap();
        let key: Key = (&uri).try_into().expect("unable to convert to Key");
        assert_eq!(key.port, 443);
        assert_eq!(key.host, "bytedance.com");
        assert_eq!(key.server_name, expected_sni("bytedance.com"));
    }

    #[test]
    fn key_http_default_port() {
        let uri = Uri::try_from("http://bytedance.com").unwrap();
        let key: Key = uri.try_into().expect("unable to convert to Key");
        assert_eq!(key.port, 80);
        assert_eq!(key.host, "bytedance.com");
        assert_eq!(key.server_name, None);
    }

    #[test]
    fn key_ipv6_brackets_stripped() {
        let uri = Uri::try_from("http://[::1]:8080").unwrap();
        let key: Key = uri.try_into().expect("unable to convert to Key");
        assert_eq!(key.port, 8080);
        assert_eq!(key.host, "::1");
        assert_eq!(key.server_name, None);
    }

    #[test]
    fn key_without_authority_is_rejected() {
        let uri = Uri::try_from("/index.html").unwrap();
        assert!(matches!(
            Key::try_from(&uri),
            Err(FromUriError::NoAuthority)
        ));
    }
}