http_stat/
request.rs

1// Copyright 2025 Tree xie.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15// Copyright 2025 Tree xie.
16//
17// Licensed under the Apache License, Version 2.0 (the "License");
18// you may not use this file except in compliance with the License.
19// You may obtain a copy of the License at
20//
21// http://www.apache.org/licenses/LICENSE-2.0
22//
23// Unless required by applicable law or agreed to in writing, software
24// distributed under the License is distributed on an "AS IS" BASIS,
25// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
26// See the License for the specific language governing permissions and
27// limitations under the License.
28
29use super::error::{Error, Result};
30use super::stats::{HttpStat, ALPN_HTTP1, ALPN_HTTP2};
31use super::SkipVerifier;
32use bytes::Bytes;
33use hickory_resolver::config::LookupIpStrategy;
34use hickory_resolver::name_server::TokioConnectionProvider;
35use hickory_resolver::TokioResolver;
36use http::HeaderValue;
37use http::Response;
38use http::Uri;
39use http::{HeaderMap, Method};
40use http_body_util::{BodyExt, Full};
41use hyper::body::Incoming;
42use hyper::Request;
43use hyper_util::rt::TokioExecutor;
44use hyper_util::rt::TokioIo;
45use std::collections::HashMap;
46use std::net::IpAddr;
47use std::net::SocketAddr;
48use std::sync::Arc;
49use std::sync::Once;
50use std::time::Duration;
51use std::time::Instant;
52use tokio::fs;
53use tokio::net::TcpStream;
54use tokio::sync::oneshot;
55use tokio::time::timeout;
56use tokio_rustls::client::TlsStream;
57use tokio_rustls::rustls::{ClientConfig, RootCertStore};
58use tokio_rustls::TlsConnector;
59
60const VERSION: &str = env!("CARGO_PKG_VERSION");
61
62#[derive(Default, Debug, Clone)]
63pub struct HttpRequest {
64    pub uri: Uri,
65    pub method: Option<Method>,
66    pub alpn_protocols: Vec<String>,
67    pub resolves: Option<HashMap<String, IpAddr>>,
68    pub headers: Option<HeaderMap<HeaderValue>>,
69    pub ip_version: Option<i32>, // 4 for IPv4, 6 for IPv6
70    pub skip_verify: bool,
71    pub output: Option<String>,
72    pub body: Option<Bytes>,
73}
74
75impl TryFrom<&str> for HttpRequest {
76    type Error = Error;
77
78    fn try_from(url: &str) -> Result<Self> {
79        let uri = url.parse::<Uri>().map_err(|e| Error::Uri { source: e })?;
80        Ok(Self {
81            uri,
82            method: None,
83            alpn_protocols: vec![ALPN_HTTP2.to_string(), ALPN_HTTP1.to_string()],
84            resolves: None,
85            headers: None,
86            ip_version: None,
87            skip_verify: false,
88            output: None,
89            body: None,
90        })
91    }
92}
93
94static INIT: Once = Once::new();
95
96fn ensure_crypto_provider() {
97    INIT.call_once(|| {
98        let _ = tokio_rustls::rustls::crypto::ring::default_provider().install_default();
99    });
100}
101
102async fn dns_resolve(req: &HttpRequest, stat: &mut HttpStat) -> Result<(SocketAddr, String)> {
103    let host = req
104        .uri
105        .host()
106        .ok_or(Error::Common {
107            category: "http".to_string(),
108            message: "host is required".to_string(),
109        })?
110        .to_string();
111    let default_port = if req.uri.scheme() == Some(&http::uri::Scheme::HTTPS) {
112        443
113    } else {
114        80
115    };
116    let port = req.uri.port_u16().unwrap_or(default_port);
117
118    // Check if we have a resolve entry for this host:port
119    if let Some(resolves) = &req.resolves {
120        let host_port = format!("{}:{}", host, port);
121        if let Some(ip) = resolves.get(&host_port) {
122            let addr = SocketAddr::new(*ip, port);
123            stat.addr = Some(addr.to_string());
124            return Ok((addr, host));
125        }
126    }
127
128    let provider = TokioConnectionProvider::default();
129    let mut builder = TokioResolver::builder(provider).map_err(|e| Error::Resolve { source: e })?;
130    if let Some(ip_version) = req.ip_version {
131        match ip_version {
132            4 => builder.options_mut().ip_strategy = LookupIpStrategy::Ipv4Only,
133            6 => builder.options_mut().ip_strategy = LookupIpStrategy::Ipv6Only,
134            _ => {}
135        }
136    }
137
138    let resolver = builder.build();
139    let dns_start = Instant::now();
140    let addr = resolver
141        .lookup_ip(&host)
142        .await
143        .map_err(|e| Error::Resolve { source: e })?;
144    stat.dns_lookup = Some(dns_start.elapsed());
145    let addr = addr.into_iter().next().ok_or(Error::Common {
146        category: "http".to_string(),
147        message: "dns lookup failed".to_string(),
148    })?;
149    let addr = SocketAddr::new(addr, port);
150    stat.addr = Some(addr.to_string());
151
152    Ok((addr, host))
153}
154
155async fn tcp_connect(addr: SocketAddr, stat: &mut HttpStat) -> Result<TcpStream> {
156    let tcp_start = Instant::now();
157    let tcp_stream = timeout(Duration::from_secs(10), TcpStream::connect(addr))
158        .await
159        .map_err(|e| Error::Timeout { source: e })?
160        .map_err(|e| Error::Io { source: e })?;
161    stat.tcp_connect = Some(tcp_start.elapsed());
162    Ok(tcp_stream)
163}
164
165async fn tls_handshake(
166    host: String,
167    tcp_stream: TcpStream,
168    alpn_protocols: Vec<String>,
169    skip_verify: bool,
170    stat: &mut HttpStat,
171) -> Result<(TlsStream<TcpStream>, bool)> {
172    let tls_start = Instant::now();
173    let mut root_store = RootCertStore::empty();
174    let certs = rustls_native_certs::load_native_certs().certs;
175
176    for cert in certs {
177        root_store
178            .add(cert)
179            .map_err(|e| Error::Rustls { source: e })?;
180    }
181    let mut config = ClientConfig::builder()
182        .with_root_certificates(root_store)
183        .with_no_client_auth();
184
185    // Skip certificate verification if requested
186    if skip_verify {
187        config
188            .dangerous()
189            .set_certificate_verifier(Arc::new(SkipVerifier));
190    }
191
192    config.alpn_protocols = alpn_protocols
193        .iter()
194        .map(|s| s.as_bytes().to_vec())
195        .collect();
196
197    let connector = TlsConnector::from(Arc::new(config));
198
199    // Perform TLS handshake
200    let tls_stream = timeout(
201        Duration::from_secs(30),
202        connector.connect(
203            host.clone()
204                .try_into()
205                .map_err(|e| Error::InvalidDnsName { source: e })?,
206            tcp_stream,
207        ),
208    )
209    .await
210    .map_err(|e| Error::Timeout { source: e })?
211    .map_err(|e| Error::Io { source: e })?;
212    stat.tls_handshake = Some(tls_start.elapsed());
213
214    let (_, session) = tls_stream.get_ref();
215
216    stat.tls = session
217        .protocol_version()
218        .map(|v| v.as_str().unwrap_or_default().to_string());
219
220    if let Some(certs) = session.peer_certificates() {
221        if let Some(cert) = certs.first() {
222            if let Ok((_, cert)) = x509_parser::parse_x509_certificate(cert.as_ref()) {
223                stat.cert_not_before = Some(cert.validity().not_before.to_string());
224                stat.cert_not_after = Some(cert.validity().not_after.to_string());
225                if let Ok(Some(sans)) = cert.subject_alternative_name() {
226                    let mut domains = Vec::new();
227                    for san in sans.value.general_names.iter() {
228                        if let x509_parser::extensions::GeneralName::DNSName(domain) = san {
229                            domains.push(domain.to_string());
230                        }
231                    }
232                    stat.cert_domains = Some(domains);
233                };
234            }
235        }
236    }
237    if let Some(cipher) = session.negotiated_cipher_suite() {
238        stat.cert_cipher = Some(format!("{:?}", cipher));
239    }
240    let mut is_http2 = false;
241    if let Some(protocol) = session.alpn_protocol() {
242        let alpn = String::from_utf8_lossy(protocol).to_string();
243        is_http2 = alpn == ALPN_HTTP2;
244        stat.alpn = Some(alpn);
245    }
246    Ok((tls_stream, is_http2))
247}
248
249async fn send_http_request(
250    req: Request<Full<Bytes>>,
251    tcp_stream: TcpStream,
252    tx: oneshot::Sender<String>,
253    stat: &mut HttpStat,
254) -> Result<Response<Incoming>> {
255    let (mut sender, conn) = timeout(
256        Duration::from_secs(30),
257        hyper::client::conn::http1::handshake(TokioIo::new(tcp_stream)),
258    )
259    .await
260    .map_err(|e| Error::Timeout { source: e })?
261    .map_err(|e| Error::Hyper { source: e })?;
262    // Spawn the connection task
263    tokio::spawn(async move {
264        if let Err(e) = conn.await {
265            let _ = tx.send(e.to_string());
266        }
267    });
268
269    let server_processing_start = Instant::now();
270    let resp = sender
271        .send_request(req)
272        .await
273        .map_err(|e| Error::Hyper { source: e })?;
274    stat.server_processing = Some(server_processing_start.elapsed());
275    Ok(resp)
276}
277
278async fn send_https_request(
279    req: Request<Full<Bytes>>,
280    tls_stream: TlsStream<TcpStream>,
281    tx: oneshot::Sender<String>,
282    stat: &mut HttpStat,
283) -> Result<Response<Incoming>> {
284    let (mut sender, conn) = timeout(
285        Duration::from_secs(30),
286        hyper::client::conn::http1::handshake(TokioIo::new(tls_stream)),
287    )
288    .await
289    .map_err(|e| Error::Timeout { source: e })?
290    .map_err(|e| Error::Hyper { source: e })?;
291    // Spawn the connection task
292    tokio::spawn(async move {
293        if let Err(e) = conn.await {
294            let _ = tx.send(e.to_string());
295        }
296    });
297
298    let server_processing_start = Instant::now();
299    let resp = sender
300        .send_request(req)
301        .await
302        .map_err(|e| Error::Hyper { source: e })?;
303    stat.server_processing = Some(server_processing_start.elapsed());
304    Ok(resp)
305}
306
307async fn send_https2_request(
308    req: Request<Full<Bytes>>,
309    tls_stream: TlsStream<TcpStream>,
310    tx: oneshot::Sender<String>,
311    stat: &mut HttpStat,
312) -> Result<Response<Incoming>> {
313    let (mut sender, conn) = timeout(
314        Duration::from_secs(30),
315        hyper::client::conn::http2::handshake(TokioExecutor::new(), TokioIo::new(tls_stream)),
316    )
317    .await
318    .map_err(|e| Error::Timeout { source: e })?
319    .map_err(|e| Error::Hyper { source: e })?;
320    // Spawn the connection task
321    tokio::spawn(async move {
322        if let Err(e) = conn.await {
323            let _ = tx.send(e.to_string());
324        }
325    });
326
327    let mut req = req;
328    *req.version_mut() = hyper::Version::HTTP_2;
329    // Remove Host header for HTTP/2 as it's replaced by :authority
330    req.headers_mut().remove("Host");
331
332    let server_processing_start = Instant::now();
333    let resp = sender
334        .send_request(req)
335        .await
336        .map_err(|e| Error::Hyper { source: e })?;
337    stat.server_processing = Some(server_processing_start.elapsed());
338    Ok(resp)
339}
340
341fn finish_with_error(mut stat: HttpStat, error: impl ToString, start: Instant) -> HttpStat {
342    stat.error = Some(error.to_string());
343    stat.total = Some(start.elapsed());
344    stat
345}
346
347pub async fn request(http_req: HttpRequest) -> HttpStat {
348    ensure_crypto_provider();
349    let start = Instant::now();
350    let mut stat = HttpStat::default();
351
352    // DNS resolution
353    let dns_result = dns_resolve(&http_req, &mut stat).await;
354    let (addr, host) = match dns_result {
355        Ok(result) => result,
356        Err(e) => {
357            return finish_with_error(stat, e, start);
358        }
359    };
360
361    // TCP connection
362    let tcp_stream = match tcp_connect(addr, &mut stat).await {
363        Ok(stream) => stream,
364        Err(e) => {
365            return finish_with_error(stat, e, start);
366        }
367    };
368
369    let uri = http_req.uri;
370    let is_https = uri.scheme() == Some(&http::uri::Scheme::HTTPS);
371    let mut builder = Request::builder()
372        .uri(&uri)
373        .method(http_req.method.unwrap_or(Method::GET));
374    let mut set_host = false;
375    let mut set_user_agent = false;
376    if let Some(headers) = http_req.headers {
377        for (key, value) in headers.iter() {
378            builder = builder.header(key, value);
379            match key.to_string().to_lowercase().as_str() {
380                "host" => set_host = true,
381                "user-agent" => set_user_agent = true,
382                _ => {}
383            }
384        }
385    }
386    if !set_host {
387        builder = builder.header("Host", host.clone());
388    }
389    if !set_user_agent {
390        builder = builder.header("User-Agent", format!("httpstat.rs/{}", VERSION));
391    }
392
393    // Build the request
394    let req = match builder.body(Full::new(http_req.body.unwrap_or_default())) {
395        Ok(req) => req,
396        Err(e) => {
397            return finish_with_error(stat, format!("Failed to build request: {}", e), start);
398        }
399    };
400
401    // Create a channel to receive connection errors
402    let (tx, mut rx) = oneshot::channel();
403
404    // Send the request based on protocol
405    let resp = if is_https {
406        // TLS handshake
407        let tls_result = tls_handshake(
408            host.clone(),
409            tcp_stream,
410            http_req.alpn_protocols,
411            http_req.skip_verify,
412            &mut stat,
413        )
414        .await;
415        let (tls_stream, is_http2) = match tls_result {
416            Ok(result) => result,
417            Err(e) => {
418                return finish_with_error(stat, e, start);
419            }
420        };
421
422        // Send HTTPS request
423        if is_http2 {
424            match send_https2_request(req, tls_stream, tx, &mut stat).await {
425                Ok(resp) => resp,
426                Err(e) => {
427                    return finish_with_error(stat, e, start);
428                }
429            }
430        } else {
431            match send_https_request(req, tls_stream, tx, &mut stat).await {
432                Ok(resp) => resp,
433                Err(e) => {
434                    return finish_with_error(stat, e, start);
435                }
436            }
437        }
438    } else {
439        // Send HTTP request
440        match send_http_request(req, tcp_stream, tx, &mut stat).await {
441            Ok(resp) => resp,
442            Err(e) => {
443                return finish_with_error(stat, e, start);
444            }
445        }
446    };
447
448    stat.status = Some(resp.status());
449    stat.headers = Some(resp.headers().clone());
450
451    // Read the response body
452    let content_transfer_start = Instant::now();
453    let body_result = resp.collect().await;
454    let body = match body_result {
455        Ok(body) => body,
456        Err(e) => {
457            return finish_with_error(stat, format!("Failed to read response body: {}", e), start);
458        }
459    };
460
461    let body_bytes = body.to_bytes();
462    if let Some(output) = http_req.output {
463        match fs::write(output, body_bytes).await {
464            Ok(_) => {}
465            Err(e) => {
466                return finish_with_error(
467                    stat,
468                    format!("Failed to write response body to file: {}", e),
469                    start,
470                );
471            }
472        }
473    } else {
474        stat.body = Some(body_bytes);
475    }
476    stat.content_transfer = Some(content_transfer_start.elapsed());
477
478    // Check for connection errors
479    if let Ok(error) = rx.try_recv() {
480        stat.error = Some(error);
481    }
482
483    stat.total = Some(start.elapsed());
484    stat
485}