http_stat/
request.rs

1// Copyright 2025 Tree xie.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15// Copyright 2025 Tree xie.
16//
17// Licensed under the Apache License, Version 2.0 (the "License");
18// you may not use this file except in compliance with the License.
19// You may obtain a copy of the License at
20//
21// http://www.apache.org/licenses/LICENSE-2.0
22//
23// Unless required by applicable law or agreed to in writing, software
24// distributed under the License is distributed on an "AS IS" BASIS,
25// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
26// See the License for the specific language governing permissions and
27// limitations under the License.
28
29use super::error::{Error, Result};
30use super::stats::{HttpStat, ALPN_HTTP1, ALPN_HTTP2};
31use super::SkipVerifier;
32use bytes::Bytes;
33use hickory_resolver::config::LookupIpStrategy;
34use hickory_resolver::name_server::TokioConnectionProvider;
35use hickory_resolver::TokioResolver;
36use http::HeaderValue;
37use http::Response;
38use http::Uri;
39use http::{HeaderMap, Method};
40use http_body_util::{BodyExt, Full};
41use hyper::body::Incoming;
42use hyper::Request;
43use hyper_util::rt::TokioExecutor;
44use hyper_util::rt::TokioIo;
45use std::collections::HashMap;
46use std::net::IpAddr;
47use std::net::SocketAddr;
48use std::sync::Arc;
49use std::sync::Once;
50use std::time::Duration;
51use std::time::Instant;
52use tokio::fs;
53use tokio::net::TcpStream;
54use tokio::sync::oneshot;
55use tokio::time::timeout;
56use tokio_rustls::client::TlsStream;
57use tokio_rustls::rustls::{ClientConfig, RootCertStore};
58use tokio_rustls::TlsConnector;
59
60#[derive(Default, Debug, Clone)]
61pub struct HttpRequest {
62    pub uri: Uri,
63    pub method: Option<Method>,
64    pub alpn_protocols: Vec<String>,
65    pub resolves: Option<HashMap<String, IpAddr>>,
66    pub headers: Option<HeaderMap<HeaderValue>>,
67    pub ip_version: Option<i32>, // -4 for IPv4, -6 for IPv6
68    pub skip_verify: bool,
69    pub output: Option<String>,
70    pub body: Option<Bytes>,
71}
72
73impl TryFrom<&str> for HttpRequest {
74    type Error = Error;
75
76    fn try_from(url: &str) -> Result<Self> {
77        let uri = url.parse::<Uri>().map_err(|e| Error::Uri { source: e })?;
78        Ok(Self {
79            uri,
80            method: None,
81            alpn_protocols: vec![ALPN_HTTP2.to_string(), ALPN_HTTP1.to_string()],
82            resolves: None,
83            headers: None,
84            ip_version: None,
85            skip_verify: false,
86            output: None,
87            body: None,
88        })
89    }
90}
91
92static INIT: Once = Once::new();
93
94fn ensure_crypto_provider() {
95    INIT.call_once(|| {
96        let _ = tokio_rustls::rustls::crypto::ring::default_provider().install_default();
97    });
98}
99
100async fn dns_resolve(req: &HttpRequest, stat: &mut HttpStat) -> Result<(SocketAddr, String)> {
101    let host = req
102        .uri
103        .host()
104        .ok_or(Error::Common {
105            category: "http".to_string(),
106            message: "host is required".to_string(),
107        })?
108        .to_string();
109    let default_port = if req.uri.scheme() == Some(&http::uri::Scheme::HTTPS) {
110        443
111    } else {
112        80
113    };
114    let port = req.uri.port_u16().unwrap_or(default_port);
115
116    // Check if we have a resolve entry for this host:port
117    if let Some(resolves) = &req.resolves {
118        let host_port = format!("{}:{}", host, port);
119        if let Some(ip) = resolves.get(&host_port) {
120            let addr = SocketAddr::new(*ip, port);
121            stat.addr = Some(addr.to_string());
122            return Ok((addr, host));
123        }
124    }
125
126    let provider = TokioConnectionProvider::default();
127    let mut builder = TokioResolver::builder(provider).map_err(|e| Error::Resolve { source: e })?;
128    if let Some(ip_version) = req.ip_version {
129        match ip_version {
130            4 => builder.options_mut().ip_strategy = LookupIpStrategy::Ipv4Only,
131            6 => builder.options_mut().ip_strategy = LookupIpStrategy::Ipv6Only,
132            _ => {}
133        }
134    }
135
136    let resolver = builder.build();
137    let dns_start = Instant::now();
138    let addr = resolver
139        .lookup_ip(&host)
140        .await
141        .map_err(|e| Error::Resolve { source: e })?;
142    stat.dns_lookup = Some(dns_start.elapsed());
143    let addr = addr.into_iter().next().ok_or(Error::Common {
144        category: "http".to_string(),
145        message: "dns lookup failed".to_string(),
146    })?;
147    let addr = SocketAddr::new(addr, port);
148    stat.addr = Some(addr.to_string());
149
150    Ok((addr, host))
151}
152
153async fn tcp_connect(addr: SocketAddr, stat: &mut HttpStat) -> Result<TcpStream> {
154    let tcp_start = Instant::now();
155    let tcp_stream = timeout(Duration::from_secs(10), TcpStream::connect(addr))
156        .await
157        .map_err(|e| Error::Timeout { source: e })?
158        .map_err(|e| Error::Io { source: e })?;
159    stat.tcp_connect = Some(tcp_start.elapsed());
160    Ok(tcp_stream)
161}
162
163async fn tls_handshake(
164    host: String,
165    tcp_stream: TcpStream,
166    alpn_protocols: Vec<String>,
167    skip_verify: bool,
168    stat: &mut HttpStat,
169) -> Result<(TlsStream<TcpStream>, bool)> {
170    let tls_start = Instant::now();
171    let mut root_store = RootCertStore::empty();
172    let certs = rustls_native_certs::load_native_certs().certs;
173
174    for cert in certs {
175        root_store
176            .add(cert)
177            .map_err(|e| Error::Rustls { source: e })?;
178    }
179    let mut config = ClientConfig::builder()
180        .with_root_certificates(root_store)
181        .with_no_client_auth();
182
183    // Skip certificate verification if requested
184    if skip_verify {
185        config
186            .dangerous()
187            .set_certificate_verifier(Arc::new(SkipVerifier));
188    }
189
190    config.alpn_protocols = alpn_protocols
191        .iter()
192        .map(|s| s.as_bytes().to_vec())
193        .collect();
194
195    let connector = TlsConnector::from(Arc::new(config));
196
197    // Perform TLS handshake
198    let tls_stream = connector
199        .connect(
200            host.clone()
201                .try_into()
202                .map_err(|e| Error::InvalidDnsName { source: e })?,
203            tcp_stream,
204        )
205        .await
206        .map_err(|e| Error::Io { source: e })?;
207    stat.tls_handshake = Some(tls_start.elapsed());
208
209    let (_, session) = tls_stream.get_ref();
210
211    stat.tls = session
212        .protocol_version()
213        .map(|v| v.as_str().unwrap_or_default().to_string());
214
215    if let Some(certs) = session.peer_certificates() {
216        if let Some(cert) = certs.first() {
217            if let Ok((_, cert)) = x509_parser::parse_x509_certificate(cert.as_ref()) {
218                stat.cert_not_before = Some(cert.validity().not_before.to_string());
219                stat.cert_not_after = Some(cert.validity().not_after.to_string());
220                if let Ok(Some(sans)) = cert.subject_alternative_name() {
221                    let mut domains = Vec::new();
222                    for san in sans.value.general_names.iter() {
223                        if let x509_parser::extensions::GeneralName::DNSName(domain) = san {
224                            domains.push(domain.to_string());
225                        }
226                    }
227                    stat.cert_domains = Some(domains);
228                };
229            }
230        }
231    }
232    if let Some(cipher) = session.negotiated_cipher_suite() {
233        stat.cert_cipher = Some(format!("{:?}", cipher));
234    }
235    let mut is_http2 = false;
236    if let Some(protocol) = session.alpn_protocol() {
237        let alpn = String::from_utf8_lossy(protocol).to_string();
238        is_http2 = alpn == ALPN_HTTP2;
239        stat.alpn = Some(alpn);
240    }
241    Ok((tls_stream, is_http2))
242}
243
244async fn send_http_request(
245    req: Request<Full<Bytes>>,
246    tcp_stream: TcpStream,
247    tx: oneshot::Sender<String>,
248    stat: &mut HttpStat,
249) -> Result<Response<Incoming>> {
250    let (mut sender, conn) = hyper::client::conn::http1::handshake(TokioIo::new(tcp_stream))
251        .await
252        .map_err(|e| Error::Hyper { source: e })?;
253    // Spawn the connection task
254    tokio::spawn(async move {
255        if let Err(e) = conn.await {
256            let _ = tx.send(e.to_string());
257        }
258    });
259
260    let server_processing_start = Instant::now();
261    let resp = sender
262        .send_request(req)
263        .await
264        .map_err(|e| Error::Hyper { source: e })?;
265    stat.server_processing = Some(server_processing_start.elapsed());
266    Ok(resp)
267}
268
269async fn send_https_request(
270    req: Request<Full<Bytes>>,
271    tls_stream: TlsStream<TcpStream>,
272    tx: oneshot::Sender<String>,
273    stat: &mut HttpStat,
274) -> Result<Response<Incoming>> {
275    let (mut sender, conn) = hyper::client::conn::http1::handshake(TokioIo::new(tls_stream))
276        .await
277        .map_err(|e| Error::Hyper { source: e })?;
278    // Spawn the connection task
279    tokio::spawn(async move {
280        if let Err(e) = conn.await {
281            let _ = tx.send(e.to_string());
282        }
283    });
284
285    let server_processing_start = Instant::now();
286    let resp = sender
287        .send_request(req)
288        .await
289        .map_err(|e| Error::Hyper { source: e })?;
290    stat.server_processing = Some(server_processing_start.elapsed());
291    Ok(resp)
292}
293
294async fn send_https2_request(
295    req: Request<Full<Bytes>>,
296    tls_stream: TlsStream<TcpStream>,
297    tx: oneshot::Sender<String>,
298    stat: &mut HttpStat,
299) -> Result<Response<Incoming>> {
300    let (mut sender, conn) =
301        hyper::client::conn::http2::handshake(TokioExecutor::new(), TokioIo::new(tls_stream))
302            .await
303            .map_err(|e| Error::Hyper { source: e })?;
304    // Spawn the connection task
305    tokio::spawn(async move {
306        if let Err(e) = conn.await {
307            let _ = tx.send(e.to_string());
308        }
309    });
310
311    let mut req = req;
312    *req.version_mut() = hyper::Version::HTTP_2;
313    // Remove Host header for HTTP/2 as it's replaced by :authority
314    req.headers_mut().remove("Host");
315
316    let server_processing_start = Instant::now();
317    let resp = sender
318        .send_request(req)
319        .await
320        .map_err(|e| Error::Hyper { source: e })?;
321    stat.server_processing = Some(server_processing_start.elapsed());
322    Ok(resp)
323}
324
325fn finish_with_error(mut stat: HttpStat, error: impl ToString, start: Instant) -> HttpStat {
326    stat.error = Some(error.to_string());
327    stat.total = Some(start.elapsed());
328    stat
329}
330
331pub async fn request(http_req: HttpRequest) -> HttpStat {
332    ensure_crypto_provider();
333    let start = Instant::now();
334    let mut stat = HttpStat::default();
335
336    // DNS resolution
337    let dns_result = dns_resolve(&http_req, &mut stat).await;
338    let (addr, host) = match dns_result {
339        Ok(result) => result,
340        Err(e) => {
341            return finish_with_error(stat, e, start);
342        }
343    };
344
345    // TCP connection
346    let tcp_stream = match tcp_connect(addr, &mut stat).await {
347        Ok(stream) => stream,
348        Err(e) => {
349            return finish_with_error(stat, e, start);
350        }
351    };
352
353    let uri = http_req.uri;
354    let is_https = uri.scheme() == Some(&http::uri::Scheme::HTTPS);
355    let mut builder = Request::builder()
356        .uri(&uri)
357        .method(http_req.method.unwrap_or(Method::GET));
358    let mut set_host = false;
359    if let Some(headers) = http_req.headers {
360        for (key, value) in headers.iter() {
361            builder = builder.header(key, value);
362            if key.to_string().to_lowercase() == "host" {
363                set_host = true;
364            }
365        }
366    }
367    if !set_host {
368        builder = builder.header("Host", host.clone());
369    }
370
371    // Build the request
372    let req = match builder.body(Full::new(http_req.body.unwrap_or_default())) {
373        Ok(req) => req,
374        Err(e) => {
375            return finish_with_error(stat, format!("Failed to build request: {}", e), start);
376        }
377    };
378
379    // Create a channel to receive connection errors
380    let (tx, mut rx) = oneshot::channel();
381
382    // Send the request based on protocol
383    let resp = if is_https {
384        // TLS handshake
385        let tls_result = tls_handshake(
386            host.clone(),
387            tcp_stream,
388            http_req.alpn_protocols,
389            http_req.skip_verify,
390            &mut stat,
391        )
392        .await;
393        let (tls_stream, is_http2) = match tls_result {
394            Ok(result) => result,
395            Err(e) => {
396                return finish_with_error(stat, e, start);
397            }
398        };
399
400        // Send HTTPS request
401        if is_http2 {
402            match send_https2_request(req, tls_stream, tx, &mut stat).await {
403                Ok(resp) => resp,
404                Err(e) => {
405                    return finish_with_error(stat, e, start);
406                }
407            }
408        } else {
409            match send_https_request(req, tls_stream, tx, &mut stat).await {
410                Ok(resp) => resp,
411                Err(e) => {
412                    return finish_with_error(stat, e, start);
413                }
414            }
415        }
416    } else {
417        // Send HTTP request
418        match send_http_request(req, tcp_stream, tx, &mut stat).await {
419            Ok(resp) => resp,
420            Err(e) => {
421                return finish_with_error(stat, e, start);
422            }
423        }
424    };
425
426    stat.status = Some(resp.status());
427    stat.headers = Some(resp.headers().clone());
428
429    // Read the response body
430    let content_transfer_start = Instant::now();
431    let body_result = resp.collect().await;
432    let body = match body_result {
433        Ok(body) => body,
434        Err(e) => {
435            return finish_with_error(stat, format!("Failed to read response body: {}", e), start);
436        }
437    };
438
439    let body_bytes = body.to_bytes();
440    if let Some(output) = http_req.output {
441        match fs::write(output, body_bytes).await {
442            Ok(_) => {}
443            Err(e) => {
444                return finish_with_error(
445                    stat,
446                    format!("Failed to write response body to file: {}", e),
447                    start,
448                );
449            }
450        }
451    } else {
452        stat.body = Some(body_bytes);
453    }
454    stat.content_transfer = Some(content_transfer_start.elapsed());
455
456    // Check for connection errors
457    if let Ok(error) = rx.try_recv() {
458        stat.error = Some(error);
459    }
460
461    stat.total = Some(start.elapsed());
462    stat
463}