monoio_http_client/client/
key.rs

1use std::{convert::Infallible, hash::Hash, net::ToSocketAddrs};
2
3use http::{Uri, Version};
4use service_async::{Param, ParamMut, ParamRef};
5use smol_str::SmolStr;
6use thiserror::Error as ThisError;
7
/// TLS server name used when the `native-tls` feature is enabled.
///
/// A thin newtype around the textual name; when `native-tls` is disabled the
/// `rustls::ServerName` type is re-exported under the same name instead.
#[cfg(feature = "native-tls")]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ServerName(
    /// The server name exactly as supplied by the caller.
    pub SmolStr,
);
11#[cfg(not(feature = "native-tls"))]
12pub use rustls::ServerName;
13
14use super::unified::UnifiedTransportAddr;
15
/// Builds a [`ServerName`] from anything convertible into a `SmolStr`
/// (only compiled with the `native-tls` feature).
#[cfg(feature = "native-tls")]
impl<T: Into<SmolStr>> From<T> for ServerName {
    fn from(value: T) -> Self {
        let name: SmolStr = value.into();
        ServerName(name)
    }
}
22
23pub struct Key {
24    pub host: SmolStr,
25    pub port: u16,
26    pub version: Version,
27
28    pub server_name: Option<ServerName>,
29}
30
31pub trait HttpVersion {
32    fn get_version(&self) -> Version;
33}
34
35impl HttpVersion for Key {
36    fn get_version(&self) -> Version {
37        self.version
38    }
39}
40
41impl Clone for Key {
42    fn clone(&self) -> Self {
43        Self {
44            host: self.host.clone(),
45            port: self.port,
46            server_name: self.server_name.clone(),
47            version: self.version,
48        }
49    }
50}
51
52impl std::fmt::Debug for Key {
53    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54        write!(f, "{}:{}:{:?}", self.host, self.port, self.version)
55    }
56}
57
58impl std::fmt::Display for Key {
59    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
60        write!(f, "{}:{}:{:?}", self.host, self.port, self.version)
61    }
62}
63
64impl PartialEq for Key {
65    fn eq(&self, other: &Self) -> bool {
66        self.host == other.host && self.port == other.port && self.version == other.version
67    }
68}
69
70impl Eq for Key {}
71
72impl Hash for Key {
73    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
74        self.host.hash(state);
75        self.port.hash(state);
76        self.version.hash(state);
77    }
78}
79
80impl ToSocketAddrs for Key {
81    type Iter = <(&'static str, u16) as ToSocketAddrs>::Iter;
82
83    fn to_socket_addrs(&self) -> std::io::Result<Self::Iter> {
84        (self.host.as_str(), self.port).to_socket_addrs()
85    }
86}
87
88impl Param<Option<ServerName>> for Key {
89    fn param(&self) -> Option<ServerName> {
90        self.server_name.clone()
91    }
92}
93
94impl ParamRef<Option<ServerName>> for Key {
95    fn param_ref(&self) -> &Option<ServerName> {
96        &self.server_name
97    }
98}
99
100impl ParamMut<Option<ServerName>> for Key {
101    fn param_mut(&mut self) -> &mut Option<ServerName> {
102        &mut self.server_name
103    }
104}
105
106impl Param<UnifiedTransportAddr> for Key {
107    fn param(&self) -> UnifiedTransportAddr {
108        if let Some(sn) = self.server_name.clone() {
109            UnifiedTransportAddr::TcpTls(self.host.clone(), self.port, sn)
110        } else {
111            UnifiedTransportAddr::Tcp(self.host.clone(), self.port)
112        }
113    }
114}
115
116#[derive(ThisError, Debug)]
117pub enum FromUriError {
118    #[error("Invalid dns name")]
119    InvalidDnsName(#[from] rustls::client::InvalidDnsNameError),
120    #[error("Scheme not supported")]
121    UnsupportScheme,
122    #[error("Missing authority in uri")]
123    NoAuthority,
124}
125
126impl From<Infallible> for FromUriError {
127    fn from(_: Infallible) -> Self {
128        unsafe { std::hint::unreachable_unchecked() }
129    }
130}
131
132impl TryFrom<&Uri> for Key {
133    type Error = FromUriError;
134
135    fn try_from(uri: &Uri) -> Result<Self, Self::Error> {
136        let (tls, default_port) = match uri.scheme() {
137            Some(scheme) if scheme == &http::uri::Scheme::HTTP => (false, 80),
138            Some(scheme) if scheme == &http::uri::Scheme::HTTPS => (true, 443),
139            _ => (false, 0),
140        };
141        let port = uri.port_u16().unwrap_or(default_port);
142
143        let host = match uri.host() {
144            Some(a) => {
145                if a.starts_with('[') && a.ends_with(']') {
146                    &a[1..a.len() - 1]
147                } else {
148                    a
149                }
150            }
151            None => return Err(FromUriError::NoAuthority),
152        };
153        let sni = if tls { Some(host) } else { None };
154        (host, sni, port).try_into().map_err(Into::into)
155    }
156}
157
158impl TryFrom<Uri> for Key {
159    type Error = FromUriError;
160
161    fn try_from(value: Uri) -> Result<Self, Self::Error> {
162        Self::try_from(&value)
163    }
164}
165
166// host, sni, port
167impl TryFrom<(&str, Option<&str>, u16)> for Key {
168    #[cfg(not(feature = "native-tls"))]
169    type Error = rustls::client::InvalidDnsNameError;
170    #[cfg(feature = "native-tls")]
171    type Error = std::convert::Infallible;
172
173    fn try_from((host, server_name, port): (&str, Option<&str>, u16)) -> Result<Self, Self::Error> {
174        let server_name = match server_name {
175            Some(s) => s,
176            None => {
177                return Ok(Self {
178                    host: host.into(),
179                    port,
180                    version: http::version::Version::HTTP_11,
181                    server_name: None,
182                })
183            }
184        };
185
186        #[cfg(not(feature = "native-tls"))]
187        let server_name = Some(ServerName::try_from(server_name)?);
188        #[cfg(feature = "native-tls")]
189        let server_name = Some(ServerName(server_name.into()));
190
191        Ok(Self {
192            host: host.into(),
193            port,
194            version: http::version::Version::HTTP_11,
195            server_name,
196        })
197    }
198}
199
#[cfg(test)]
mod tests {
    use super::*;

    /// Host name shared by the domain-based test cases.
    const HOST: &str = "bytedance.com";

    /// Parses a URI literal, panicking on malformed input.
    fn uri(text: &str) -> Uri {
        Uri::try_from(text).unwrap()
    }

    /// A `(host, sni, port)` tuple keeps the port it was given.
    #[test]
    fn key_default_port() {
        let key: Key = (HOST, Some(HOST), 80)
            .try_into()
            .expect("unable to convert to Key");

        assert_eq!(key.port, 80);
        assert_eq!(key.host, HOST);
        #[cfg(feature = "rustls")]
        assert_eq!(key.server_name, Some(HOST.try_into().unwrap()));
        #[cfg(all(feature = "native-tls", not(feature = "rustls")))]
        assert_eq!(key.server_name, Some(HOST.into()));
    }

    /// An explicit port in an owned URI overrides the scheme default.
    #[test]
    fn key_specify_port() {
        let target = uri("https://bytedance.com:12345");
        let key: Key = target.try_into().expect("unable to convert to Key");

        assert_eq!(key.port, 12345);
        assert_eq!(key.host, HOST);
        #[cfg(feature = "rustls")]
        assert_eq!(key.server_name, Some(HOST.try_into().unwrap()));
        #[cfg(all(feature = "native-tls", not(feature = "rustls")))]
        assert_eq!(key.server_name, Some(HOST.into()));
    }

    /// Plain `http` never produces a server name, even on port 443.
    #[test]
    fn key_ip_http() {
        let target = uri("http://1.1.1.1:443");
        let key: Key = target.try_into().expect("unable to convert to Key");

        assert_eq!(key.port, 443);
        assert_eq!(key.host, "1.1.1.1");
        assert_eq!(key.server_name, None);
    }

    /// A borrowed `https` URI without a port falls back to 443.
    #[test]
    fn key_uri() {
        let target = uri("https://bytedance.com");
        let key: Key = (&target).try_into().expect("unable to convert to Key");

        assert_eq!(key.port, 443);
        assert_eq!(key.host, HOST);
        #[cfg(feature = "rustls")]
        assert_eq!(key.server_name, Some(HOST.try_into().unwrap()));
        #[cfg(all(feature = "native-tls", not(feature = "rustls")))]
        assert_eq!(key.server_name, Some(HOST.into()));
    }
}