1use super::error::{Error, Result};
30use super::stats::{HttpStat, ALPN_HTTP1, ALPN_HTTP2};
31use super::SkipVerifier;
32use bytes::Bytes;
33use hickory_resolver::config::LookupIpStrategy;
34use hickory_resolver::name_server::TokioConnectionProvider;
35use hickory_resolver::TokioResolver;
36use http::HeaderValue;
37use http::Response;
38use http::Uri;
39use http::{HeaderMap, Method};
40use http_body_util::{BodyExt, Full};
41use hyper::body::Incoming;
42use hyper::Request;
43use hyper_util::rt::TokioExecutor;
44use hyper_util::rt::TokioIo;
45use std::collections::HashMap;
46use std::net::IpAddr;
47use std::net::SocketAddr;
48use std::sync::Arc;
49use std::sync::Once;
50use std::time::Duration;
51use std::time::Instant;
52use tokio::fs;
53use tokio::net::TcpStream;
54use tokio::sync::oneshot;
55use tokio::time::timeout;
56use tokio_rustls::client::TlsStream;
57use tokio_rustls::rustls::{ClientConfig, RootCertStore};
58use tokio_rustls::TlsConnector;
59
/// Description of a single request to time.
///
/// Usually created from a URL string via `TryFrom<&str>`, then customised field by field.
#[derive(Default, Debug, Clone)]
pub struct HttpRequest {
    /// Target URI. Its scheme selects plain TCP vs TLS (`https`), and its host/port are
    /// what gets resolved and connected to.
    pub uri: Uri,
    /// HTTP method; `GET` is used when `None`.
    pub method: Option<Method>,
    /// Protocols offered via TLS ALPN, in preference order (only used for `https`).
    pub alpn_protocols: Vec<String>,
    /// Static DNS overrides keyed by `"host:port"`; a hit skips the DNS lookup.
    pub resolves: Option<HashMap<String, IpAddr>>,
    /// Extra request headers. If one of them is `host`, no default `Host` header is added.
    pub headers: Option<HeaderMap<HeaderValue>>,
    /// DNS address-family restriction: `4` = IPv4 only, `6` = IPv6 only; any other value
    /// (or `None`) leaves the resolver's default strategy.
    pub ip_version: Option<i32>,
    /// When `true`, the server's TLS certificate is not verified.
    pub skip_verify: bool,
    /// If set, the response body is written to this file path instead of being kept in
    /// the returned stats.
    pub output: Option<String>,
    /// Request body; an empty body is sent when `None`.
    pub body: Option<Bytes>,
}
72
73impl TryFrom<&str> for HttpRequest {
74 type Error = Error;
75
76 fn try_from(url: &str) -> Result<Self> {
77 let uri = url.parse::<Uri>().map_err(|e| Error::Uri { source: e })?;
78 Ok(Self {
79 uri,
80 method: None,
81 alpn_protocols: vec![ALPN_HTTP2.to_string(), ALPN_HTTP1.to_string()],
82 resolves: None,
83 headers: None,
84 ip_version: None,
85 skip_verify: false,
86 output: None,
87 body: None,
88 })
89 }
90}
91
92static INIT: Once = Once::new();
93
94fn ensure_crypto_provider() {
95 INIT.call_once(|| {
96 let _ = tokio_rustls::rustls::crypto::ring::default_provider().install_default();
97 });
98}
99
100async fn dns_resolve(req: &HttpRequest, stat: &mut HttpStat) -> Result<(SocketAddr, String)> {
101 let host = req
102 .uri
103 .host()
104 .ok_or(Error::Common {
105 category: "http".to_string(),
106 message: "host is required".to_string(),
107 })?
108 .to_string();
109 let default_port = if req.uri.scheme() == Some(&http::uri::Scheme::HTTPS) {
110 443
111 } else {
112 80
113 };
114 let port = req.uri.port_u16().unwrap_or(default_port);
115
116 if let Some(resolves) = &req.resolves {
118 let host_port = format!("{}:{}", host, port);
119 if let Some(ip) = resolves.get(&host_port) {
120 let addr = SocketAddr::new(*ip, port);
121 stat.addr = Some(addr.to_string());
122 return Ok((addr, host));
123 }
124 }
125
126 let provider = TokioConnectionProvider::default();
127 let mut builder = TokioResolver::builder(provider).map_err(|e| Error::Resolve { source: e })?;
128 if let Some(ip_version) = req.ip_version {
129 match ip_version {
130 4 => builder.options_mut().ip_strategy = LookupIpStrategy::Ipv4Only,
131 6 => builder.options_mut().ip_strategy = LookupIpStrategy::Ipv6Only,
132 _ => {}
133 }
134 }
135
136 let resolver = builder.build();
137 let dns_start = Instant::now();
138 let addr = resolver
139 .lookup_ip(&host)
140 .await
141 .map_err(|e| Error::Resolve { source: e })?;
142 stat.dns_lookup = Some(dns_start.elapsed());
143 let addr = addr.into_iter().next().ok_or(Error::Common {
144 category: "http".to_string(),
145 message: "dns lookup failed".to_string(),
146 })?;
147 let addr = SocketAddr::new(addr, port);
148 stat.addr = Some(addr.to_string());
149
150 Ok((addr, host))
151}
152
153async fn tcp_connect(addr: SocketAddr, stat: &mut HttpStat) -> Result<TcpStream> {
154 let tcp_start = Instant::now();
155 let tcp_stream = timeout(Duration::from_secs(10), TcpStream::connect(addr))
156 .await
157 .map_err(|e| Error::Timeout { source: e })?
158 .map_err(|e| Error::Io { source: e })?;
159 stat.tcp_connect = Some(tcp_start.elapsed());
160 Ok(tcp_stream)
161}
162
163async fn tls_handshake(
164 host: String,
165 tcp_stream: TcpStream,
166 alpn_protocols: Vec<String>,
167 skip_verify: bool,
168 stat: &mut HttpStat,
169) -> Result<(TlsStream<TcpStream>, bool)> {
170 let tls_start = Instant::now();
171 let mut root_store = RootCertStore::empty();
172 let certs = rustls_native_certs::load_native_certs().certs;
173
174 for cert in certs {
175 root_store
176 .add(cert)
177 .map_err(|e| Error::Rustls { source: e })?;
178 }
179 let mut config = ClientConfig::builder()
180 .with_root_certificates(root_store)
181 .with_no_client_auth();
182
183 if skip_verify {
185 config
186 .dangerous()
187 .set_certificate_verifier(Arc::new(SkipVerifier));
188 }
189
190 config.alpn_protocols = alpn_protocols
191 .iter()
192 .map(|s| s.as_bytes().to_vec())
193 .collect();
194
195 let connector = TlsConnector::from(Arc::new(config));
196
197 let tls_stream = connector
199 .connect(
200 host.clone()
201 .try_into()
202 .map_err(|e| Error::InvalidDnsName { source: e })?,
203 tcp_stream,
204 )
205 .await
206 .map_err(|e| Error::Io { source: e })?;
207 stat.tls_handshake = Some(tls_start.elapsed());
208
209 let (_, session) = tls_stream.get_ref();
210
211 stat.tls = session
212 .protocol_version()
213 .map(|v| v.as_str().unwrap_or_default().to_string());
214
215 if let Some(certs) = session.peer_certificates() {
216 if let Some(cert) = certs.first() {
217 if let Ok((_, cert)) = x509_parser::parse_x509_certificate(cert.as_ref()) {
218 stat.cert_not_before = Some(cert.validity().not_before.to_string());
219 stat.cert_not_after = Some(cert.validity().not_after.to_string());
220 if let Ok(Some(sans)) = cert.subject_alternative_name() {
221 let mut domains = Vec::new();
222 for san in sans.value.general_names.iter() {
223 if let x509_parser::extensions::GeneralName::DNSName(domain) = san {
224 domains.push(domain.to_string());
225 }
226 }
227 stat.cert_domains = Some(domains);
228 };
229 }
230 }
231 }
232 if let Some(cipher) = session.negotiated_cipher_suite() {
233 stat.cert_cipher = Some(format!("{:?}", cipher));
234 }
235 let mut is_http2 = false;
236 if let Some(protocol) = session.alpn_protocol() {
237 let alpn = String::from_utf8_lossy(protocol).to_string();
238 is_http2 = alpn == ALPN_HTTP2;
239 stat.alpn = Some(alpn);
240 }
241 Ok((tls_stream, is_http2))
242}
243
244async fn send_http_request(
245 req: Request<Full<Bytes>>,
246 tcp_stream: TcpStream,
247 tx: oneshot::Sender<String>,
248 stat: &mut HttpStat,
249) -> Result<Response<Incoming>> {
250 let (mut sender, conn) = hyper::client::conn::http1::handshake(TokioIo::new(tcp_stream))
251 .await
252 .map_err(|e| Error::Hyper { source: e })?;
253 tokio::spawn(async move {
255 if let Err(e) = conn.await {
256 let _ = tx.send(e.to_string());
257 }
258 });
259
260 let server_processing_start = Instant::now();
261 let resp = sender
262 .send_request(req)
263 .await
264 .map_err(|e| Error::Hyper { source: e })?;
265 stat.server_processing = Some(server_processing_start.elapsed());
266 Ok(resp)
267}
268
269async fn send_https_request(
270 req: Request<Full<Bytes>>,
271 tls_stream: TlsStream<TcpStream>,
272 tx: oneshot::Sender<String>,
273 stat: &mut HttpStat,
274) -> Result<Response<Incoming>> {
275 let (mut sender, conn) = hyper::client::conn::http1::handshake(TokioIo::new(tls_stream))
276 .await
277 .map_err(|e| Error::Hyper { source: e })?;
278 tokio::spawn(async move {
280 if let Err(e) = conn.await {
281 let _ = tx.send(e.to_string());
282 }
283 });
284
285 let server_processing_start = Instant::now();
286 let resp = sender
287 .send_request(req)
288 .await
289 .map_err(|e| Error::Hyper { source: e })?;
290 stat.server_processing = Some(server_processing_start.elapsed());
291 Ok(resp)
292}
293
294async fn send_https2_request(
295 req: Request<Full<Bytes>>,
296 tls_stream: TlsStream<TcpStream>,
297 tx: oneshot::Sender<String>,
298 stat: &mut HttpStat,
299) -> Result<Response<Incoming>> {
300 let (mut sender, conn) =
301 hyper::client::conn::http2::handshake(TokioExecutor::new(), TokioIo::new(tls_stream))
302 .await
303 .map_err(|e| Error::Hyper { source: e })?;
304 tokio::spawn(async move {
306 if let Err(e) = conn.await {
307 let _ = tx.send(e.to_string());
308 }
309 });
310
311 let mut req = req;
312 *req.version_mut() = hyper::Version::HTTP_2;
313 req.headers_mut().remove("Host");
315
316 let server_processing_start = Instant::now();
317 let resp = sender
318 .send_request(req)
319 .await
320 .map_err(|e| Error::Hyper { source: e })?;
321 stat.server_processing = Some(server_processing_start.elapsed());
322 Ok(resp)
323}
324
325fn finish_with_error(mut stat: HttpStat, error: impl ToString, start: Instant) -> HttpStat {
326 stat.error = Some(error.to_string());
327 stat.total = Some(start.elapsed());
328 stat
329}
330
331pub async fn request(http_req: HttpRequest) -> HttpStat {
332 ensure_crypto_provider();
333 let start = Instant::now();
334 let mut stat = HttpStat::default();
335
336 let dns_result = dns_resolve(&http_req, &mut stat).await;
338 let (addr, host) = match dns_result {
339 Ok(result) => result,
340 Err(e) => {
341 return finish_with_error(stat, e, start);
342 }
343 };
344
345 let tcp_stream = match tcp_connect(addr, &mut stat).await {
347 Ok(stream) => stream,
348 Err(e) => {
349 return finish_with_error(stat, e, start);
350 }
351 };
352
353 let uri = http_req.uri;
354 let is_https = uri.scheme() == Some(&http::uri::Scheme::HTTPS);
355 let mut builder = Request::builder()
356 .uri(&uri)
357 .method(http_req.method.unwrap_or(Method::GET));
358 let mut set_host = false;
359 if let Some(headers) = http_req.headers {
360 for (key, value) in headers.iter() {
361 builder = builder.header(key, value);
362 if key.to_string().to_lowercase() == "host" {
363 set_host = true;
364 }
365 }
366 }
367 if !set_host {
368 builder = builder.header("Host", host.clone());
369 }
370
371 let req = match builder.body(Full::new(http_req.body.unwrap_or_default())) {
373 Ok(req) => req,
374 Err(e) => {
375 return finish_with_error(stat, format!("Failed to build request: {}", e), start);
376 }
377 };
378
379 let (tx, mut rx) = oneshot::channel();
381
382 let resp = if is_https {
384 let tls_result = tls_handshake(
386 host.clone(),
387 tcp_stream,
388 http_req.alpn_protocols,
389 http_req.skip_verify,
390 &mut stat,
391 )
392 .await;
393 let (tls_stream, is_http2) = match tls_result {
394 Ok(result) => result,
395 Err(e) => {
396 return finish_with_error(stat, e, start);
397 }
398 };
399
400 if is_http2 {
402 match send_https2_request(req, tls_stream, tx, &mut stat).await {
403 Ok(resp) => resp,
404 Err(e) => {
405 return finish_with_error(stat, e, start);
406 }
407 }
408 } else {
409 match send_https_request(req, tls_stream, tx, &mut stat).await {
410 Ok(resp) => resp,
411 Err(e) => {
412 return finish_with_error(stat, e, start);
413 }
414 }
415 }
416 } else {
417 match send_http_request(req, tcp_stream, tx, &mut stat).await {
419 Ok(resp) => resp,
420 Err(e) => {
421 return finish_with_error(stat, e, start);
422 }
423 }
424 };
425
426 stat.status = Some(resp.status());
427 stat.headers = Some(resp.headers().clone());
428
429 let content_transfer_start = Instant::now();
431 let body_result = resp.collect().await;
432 let body = match body_result {
433 Ok(body) => body,
434 Err(e) => {
435 return finish_with_error(stat, format!("Failed to read response body: {}", e), start);
436 }
437 };
438
439 let body_bytes = body.to_bytes();
440 if let Some(output) = http_req.output {
441 match fs::write(output, body_bytes).await {
442 Ok(_) => {}
443 Err(e) => {
444 return finish_with_error(
445 stat,
446 format!("Failed to write response body to file: {}", e),
447 start,
448 );
449 }
450 }
451 } else {
452 stat.body = Some(body_bytes);
453 }
454 stat.content_transfer = Some(content_transfer_start.elapsed());
455
456 if let Ok(error) = rx.try_recv() {
458 stat.error = Some(error);
459 }
460
461 stat.total = Some(start.elapsed());
462 stat
463}