1use super::error::{Error, Result};
30use super::stats::{HttpStat, ALPN_HTTP1, ALPN_HTTP2};
31use super::SkipVerifier;
32use bytes::Bytes;
33use hickory_resolver::config::LookupIpStrategy;
34use hickory_resolver::name_server::TokioConnectionProvider;
35use hickory_resolver::TokioResolver;
36use http::HeaderValue;
37use http::Response;
38use http::Uri;
39use http::{HeaderMap, Method};
40use http_body_util::{BodyExt, Full};
41use hyper::body::Incoming;
42use hyper::Request;
43use hyper_util::rt::TokioExecutor;
44use hyper_util::rt::TokioIo;
45use std::collections::HashMap;
46use std::net::IpAddr;
47use std::net::SocketAddr;
48use std::sync::Arc;
49use std::sync::Once;
50use std::time::Duration;
51use std::time::Instant;
52use tokio::fs;
53use tokio::net::TcpStream;
54use tokio::sync::oneshot;
55use tokio::time::timeout;
56use tokio_rustls::client::TlsStream;
57use tokio_rustls::rustls::{ClientConfig, RootCertStore};
58use tokio_rustls::TlsConnector;
59
/// Version of this crate, read from `CARGO_PKG_VERSION` at compile time.
/// Used to build the default `User-Agent` header (`httpstat.rs/<version>`).
const VERSION: &str = env!("CARGO_PKG_VERSION");
61
62#[derive(Default, Debug, Clone)]
63pub struct HttpRequest {
64 pub uri: Uri,
65 pub method: Option<Method>,
66 pub alpn_protocols: Vec<String>,
67 pub resolves: Option<HashMap<String, IpAddr>>,
68 pub headers: Option<HeaderMap<HeaderValue>>,
69 pub ip_version: Option<i32>, pub skip_verify: bool,
71 pub output: Option<String>,
72 pub body: Option<Bytes>,
73}
74
75impl TryFrom<&str> for HttpRequest {
76 type Error = Error;
77
78 fn try_from(url: &str) -> Result<Self> {
79 let uri = url.parse::<Uri>().map_err(|e| Error::Uri { source: e })?;
80 Ok(Self {
81 uri,
82 method: None,
83 alpn_protocols: vec![ALPN_HTTP2.to_string(), ALPN_HTTP1.to_string()],
84 resolves: None,
85 headers: None,
86 ip_version: None,
87 skip_verify: false,
88 output: None,
89 body: None,
90 })
91 }
92}
93
94static INIT: Once = Once::new();
95
96fn ensure_crypto_provider() {
97 INIT.call_once(|| {
98 let _ = tokio_rustls::rustls::crypto::ring::default_provider().install_default();
99 });
100}
101
102async fn dns_resolve(req: &HttpRequest, stat: &mut HttpStat) -> Result<(SocketAddr, String)> {
103 let host = req
104 .uri
105 .host()
106 .ok_or(Error::Common {
107 category: "http".to_string(),
108 message: "host is required".to_string(),
109 })?
110 .to_string();
111 let default_port = if req.uri.scheme() == Some(&http::uri::Scheme::HTTPS) {
112 443
113 } else {
114 80
115 };
116 let port = req.uri.port_u16().unwrap_or(default_port);
117
118 if let Some(resolves) = &req.resolves {
120 let host_port = format!("{}:{}", host, port);
121 if let Some(ip) = resolves.get(&host_port) {
122 let addr = SocketAddr::new(*ip, port);
123 stat.addr = Some(addr.to_string());
124 return Ok((addr, host));
125 }
126 }
127
128 let provider = TokioConnectionProvider::default();
129 let mut builder = TokioResolver::builder(provider).map_err(|e| Error::Resolve { source: e })?;
130 if let Some(ip_version) = req.ip_version {
131 match ip_version {
132 4 => builder.options_mut().ip_strategy = LookupIpStrategy::Ipv4Only,
133 6 => builder.options_mut().ip_strategy = LookupIpStrategy::Ipv6Only,
134 _ => {}
135 }
136 }
137
138 let resolver = builder.build();
139 let dns_start = Instant::now();
140 let addr = resolver
141 .lookup_ip(&host)
142 .await
143 .map_err(|e| Error::Resolve { source: e })?;
144 stat.dns_lookup = Some(dns_start.elapsed());
145 let addr = addr.into_iter().next().ok_or(Error::Common {
146 category: "http".to_string(),
147 message: "dns lookup failed".to_string(),
148 })?;
149 let addr = SocketAddr::new(addr, port);
150 stat.addr = Some(addr.to_string());
151
152 Ok((addr, host))
153}
154
155async fn tcp_connect(addr: SocketAddr, stat: &mut HttpStat) -> Result<TcpStream> {
156 let tcp_start = Instant::now();
157 let tcp_stream = timeout(Duration::from_secs(10), TcpStream::connect(addr))
158 .await
159 .map_err(|e| Error::Timeout { source: e })?
160 .map_err(|e| Error::Io { source: e })?;
161 stat.tcp_connect = Some(tcp_start.elapsed());
162 Ok(tcp_stream)
163}
164
165async fn tls_handshake(
166 host: String,
167 tcp_stream: TcpStream,
168 alpn_protocols: Vec<String>,
169 skip_verify: bool,
170 stat: &mut HttpStat,
171) -> Result<(TlsStream<TcpStream>, bool)> {
172 let tls_start = Instant::now();
173 let mut root_store = RootCertStore::empty();
174 let certs = rustls_native_certs::load_native_certs().certs;
175
176 for cert in certs {
177 root_store
178 .add(cert)
179 .map_err(|e| Error::Rustls { source: e })?;
180 }
181 let mut config = ClientConfig::builder()
182 .with_root_certificates(root_store)
183 .with_no_client_auth();
184
185 if skip_verify {
187 config
188 .dangerous()
189 .set_certificate_verifier(Arc::new(SkipVerifier));
190 }
191
192 config.alpn_protocols = alpn_protocols
193 .iter()
194 .map(|s| s.as_bytes().to_vec())
195 .collect();
196
197 let connector = TlsConnector::from(Arc::new(config));
198
199 let tls_stream = timeout(
201 Duration::from_secs(30),
202 connector.connect(
203 host.clone()
204 .try_into()
205 .map_err(|e| Error::InvalidDnsName { source: e })?,
206 tcp_stream,
207 ),
208 )
209 .await
210 .map_err(|e| Error::Timeout { source: e })?
211 .map_err(|e| Error::Io { source: e })?;
212 stat.tls_handshake = Some(tls_start.elapsed());
213
214 let (_, session) = tls_stream.get_ref();
215
216 stat.tls = session
217 .protocol_version()
218 .map(|v| v.as_str().unwrap_or_default().to_string());
219
220 if let Some(certs) = session.peer_certificates() {
221 if let Some(cert) = certs.first() {
222 if let Ok((_, cert)) = x509_parser::parse_x509_certificate(cert.as_ref()) {
223 stat.cert_not_before = Some(cert.validity().not_before.to_string());
224 stat.cert_not_after = Some(cert.validity().not_after.to_string());
225 if let Ok(Some(sans)) = cert.subject_alternative_name() {
226 let mut domains = Vec::new();
227 for san in sans.value.general_names.iter() {
228 if let x509_parser::extensions::GeneralName::DNSName(domain) = san {
229 domains.push(domain.to_string());
230 }
231 }
232 stat.cert_domains = Some(domains);
233 };
234 }
235 }
236 }
237 if let Some(cipher) = session.negotiated_cipher_suite() {
238 stat.cert_cipher = Some(format!("{:?}", cipher));
239 }
240 let mut is_http2 = false;
241 if let Some(protocol) = session.alpn_protocol() {
242 let alpn = String::from_utf8_lossy(protocol).to_string();
243 is_http2 = alpn == ALPN_HTTP2;
244 stat.alpn = Some(alpn);
245 }
246 Ok((tls_stream, is_http2))
247}
248
249async fn send_http_request(
250 req: Request<Full<Bytes>>,
251 tcp_stream: TcpStream,
252 tx: oneshot::Sender<String>,
253 stat: &mut HttpStat,
254) -> Result<Response<Incoming>> {
255 let (mut sender, conn) = timeout(
256 Duration::from_secs(30),
257 hyper::client::conn::http1::handshake(TokioIo::new(tcp_stream)),
258 )
259 .await
260 .map_err(|e| Error::Timeout { source: e })?
261 .map_err(|e| Error::Hyper { source: e })?;
262 tokio::spawn(async move {
264 if let Err(e) = conn.await {
265 let _ = tx.send(e.to_string());
266 }
267 });
268
269 let server_processing_start = Instant::now();
270 let resp = sender
271 .send_request(req)
272 .await
273 .map_err(|e| Error::Hyper { source: e })?;
274 stat.server_processing = Some(server_processing_start.elapsed());
275 Ok(resp)
276}
277
278async fn send_https_request(
279 req: Request<Full<Bytes>>,
280 tls_stream: TlsStream<TcpStream>,
281 tx: oneshot::Sender<String>,
282 stat: &mut HttpStat,
283) -> Result<Response<Incoming>> {
284 let (mut sender, conn) = timeout(
285 Duration::from_secs(30),
286 hyper::client::conn::http1::handshake(TokioIo::new(tls_stream)),
287 )
288 .await
289 .map_err(|e| Error::Timeout { source: e })?
290 .map_err(|e| Error::Hyper { source: e })?;
291 tokio::spawn(async move {
293 if let Err(e) = conn.await {
294 let _ = tx.send(e.to_string());
295 }
296 });
297
298 let server_processing_start = Instant::now();
299 let resp = sender
300 .send_request(req)
301 .await
302 .map_err(|e| Error::Hyper { source: e })?;
303 stat.server_processing = Some(server_processing_start.elapsed());
304 Ok(resp)
305}
306
307async fn send_https2_request(
308 req: Request<Full<Bytes>>,
309 tls_stream: TlsStream<TcpStream>,
310 tx: oneshot::Sender<String>,
311 stat: &mut HttpStat,
312) -> Result<Response<Incoming>> {
313 let (mut sender, conn) = timeout(
314 Duration::from_secs(30),
315 hyper::client::conn::http2::handshake(TokioExecutor::new(), TokioIo::new(tls_stream)),
316 )
317 .await
318 .map_err(|e| Error::Timeout { source: e })?
319 .map_err(|e| Error::Hyper { source: e })?;
320 tokio::spawn(async move {
322 if let Err(e) = conn.await {
323 let _ = tx.send(e.to_string());
324 }
325 });
326
327 let mut req = req;
328 *req.version_mut() = hyper::Version::HTTP_2;
329 req.headers_mut().remove("Host");
331
332 let server_processing_start = Instant::now();
333 let resp = sender
334 .send_request(req)
335 .await
336 .map_err(|e| Error::Hyper { source: e })?;
337 stat.server_processing = Some(server_processing_start.elapsed());
338 Ok(resp)
339}
340
341fn finish_with_error(mut stat: HttpStat, error: impl ToString, start: Instant) -> HttpStat {
342 stat.error = Some(error.to_string());
343 stat.total = Some(start.elapsed());
344 stat
345}
346
347pub async fn request(http_req: HttpRequest) -> HttpStat {
348 ensure_crypto_provider();
349 let start = Instant::now();
350 let mut stat = HttpStat::default();
351
352 let dns_result = dns_resolve(&http_req, &mut stat).await;
354 let (addr, host) = match dns_result {
355 Ok(result) => result,
356 Err(e) => {
357 return finish_with_error(stat, e, start);
358 }
359 };
360
361 let tcp_stream = match tcp_connect(addr, &mut stat).await {
363 Ok(stream) => stream,
364 Err(e) => {
365 return finish_with_error(stat, e, start);
366 }
367 };
368
369 let uri = http_req.uri;
370 let is_https = uri.scheme() == Some(&http::uri::Scheme::HTTPS);
371 let mut builder = Request::builder()
372 .uri(&uri)
373 .method(http_req.method.unwrap_or(Method::GET));
374 let mut set_host = false;
375 let mut set_user_agent = false;
376 if let Some(headers) = http_req.headers {
377 for (key, value) in headers.iter() {
378 builder = builder.header(key, value);
379 match key.to_string().to_lowercase().as_str() {
380 "host" => set_host = true,
381 "user-agent" => set_user_agent = true,
382 _ => {}
383 }
384 }
385 }
386 if !set_host {
387 builder = builder.header("Host", host.clone());
388 }
389 if !set_user_agent {
390 builder = builder.header("User-Agent", format!("httpstat.rs/{}", VERSION));
391 }
392
393 let req = match builder.body(Full::new(http_req.body.unwrap_or_default())) {
395 Ok(req) => req,
396 Err(e) => {
397 return finish_with_error(stat, format!("Failed to build request: {}", e), start);
398 }
399 };
400
401 let (tx, mut rx) = oneshot::channel();
403
404 let resp = if is_https {
406 let tls_result = tls_handshake(
408 host.clone(),
409 tcp_stream,
410 http_req.alpn_protocols,
411 http_req.skip_verify,
412 &mut stat,
413 )
414 .await;
415 let (tls_stream, is_http2) = match tls_result {
416 Ok(result) => result,
417 Err(e) => {
418 return finish_with_error(stat, e, start);
419 }
420 };
421
422 if is_http2 {
424 match send_https2_request(req, tls_stream, tx, &mut stat).await {
425 Ok(resp) => resp,
426 Err(e) => {
427 return finish_with_error(stat, e, start);
428 }
429 }
430 } else {
431 match send_https_request(req, tls_stream, tx, &mut stat).await {
432 Ok(resp) => resp,
433 Err(e) => {
434 return finish_with_error(stat, e, start);
435 }
436 }
437 }
438 } else {
439 match send_http_request(req, tcp_stream, tx, &mut stat).await {
441 Ok(resp) => resp,
442 Err(e) => {
443 return finish_with_error(stat, e, start);
444 }
445 }
446 };
447
448 stat.status = Some(resp.status());
449 stat.headers = Some(resp.headers().clone());
450
451 let content_transfer_start = Instant::now();
453 let body_result = resp.collect().await;
454 let body = match body_result {
455 Ok(body) => body,
456 Err(e) => {
457 return finish_with_error(stat, format!("Failed to read response body: {}", e), start);
458 }
459 };
460
461 let body_bytes = body.to_bytes();
462 if let Some(output) = http_req.output {
463 match fs::write(output, body_bytes).await {
464 Ok(_) => {}
465 Err(e) => {
466 return finish_with_error(
467 stat,
468 format!("Failed to write response body to file: {}", e),
469 start,
470 );
471 }
472 }
473 } else {
474 stat.body = Some(body_bytes);
475 }
476 stat.content_transfer = Some(content_transfer_start.elapsed());
477
478 if let Ok(error) = rx.try_recv() {
480 stat.error = Some(error);
481 }
482
483 stat.total = Some(start.elapsed());
484 stat
485}