use std::collections::HashMap;
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncWrite};
use crate::error::Error;
use crate::protocol::http::response::{PushedResponse, Response, ResponseHttpVersion};
use crate::throttle::{RateLimiter, SpeedLimits, THROTTLE_CHUNK_SIZE};
/// Optional HTTP/2 settings consumed by `handshake`.
///
/// Every field is an `Option`; only fields set to `Some` are forwarded to the `h2`
/// client builder, so `None` keeps whatever default `h2` uses.
#[derive(Clone, Debug, Default)]
pub struct Http2Config {
    /// Forwarded to `h2::client::Builder::initial_window_size` (per-stream window).
    pub window_size: Option<u32>,
    /// Forwarded to `h2::client::Builder::initial_connection_window_size`.
    pub connection_window_size: Option<u32>,
    /// Forwarded to `h2::client::Builder::max_frame_size`.
    pub max_frame_size: Option<u32>,
    /// Forwarded to `h2::client::Builder::max_header_list_size`.
    pub max_header_list_size: Option<u32>,
    /// Forwarded to `h2::client::Builder::enable_push` (whether server push is allowed).
    pub enable_push: Option<bool>,
    /// Requested stream weight.
    /// NOTE(review): `handshake` never reads this field, so it currently has no effect.
    pub stream_weight: Option<u16>,
    /// When set, a PING is sent on the connection every `ping_interval`.
    pub ping_interval: Option<Duration>,
}
pub async fn handshake<S>(
stream: S,
h2_config: &Http2Config,
) -> Result<h2::client::SendRequest<bytes::Bytes>, Error>
where
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
let mut builder = h2::client::Builder::new();
if let Some(window_size) = h2_config.window_size {
let _b = builder.initial_window_size(window_size);
}
if let Some(conn_window) = h2_config.connection_window_size {
let _b = builder.initial_connection_window_size(conn_window);
}
if let Some(max_frame) = h2_config.max_frame_size {
let _b = builder.max_frame_size(max_frame);
}
if let Some(max_header) = h2_config.max_header_list_size {
let _b = builder.max_header_list_size(max_header);
}
if let Some(enable_push) = h2_config.enable_push {
let _b = builder.enable_push(enable_push);
}
let (client, mut h2_conn): (
h2::client::SendRequest<bytes::Bytes>,
h2::client::Connection<S, bytes::Bytes>,
) = builder
.handshake(stream)
.await
.map_err(|e| Error::Http(format!("h2 handshake failed: {e}")))?;
if let Some(interval) = h2_config.ping_interval {
if let Some(ping_pong) = h2_conn.ping_pong() {
let _handle = tokio::spawn(ping_keepalive(ping_pong, interval));
}
}
let _handle = tokio::spawn(async move {
let _r = h2_conn.await;
});
Ok(client)
}
#[allow(clippy::too_many_arguments)]
pub async fn send_request(
client: h2::client::SendRequest<bytes::Bytes>,
method: &str,
host: &str,
request_target: &str,
custom_headers: &[(String, String)],
body: Option<&[u8]>,
url: &str,
speed_limits: &SpeedLimits,
) -> Result<(Response, h2::client::SendRequest<bytes::Bytes>), Error> {
let uri: http::Uri = format!("https://{host}{request_target}")
.parse()
.map_err(|e: http::uri::InvalidUri| Error::Http(format!("invalid URI: {e}")))?;
let method: http::Method = method
.parse()
.map_err(|e: http::method::InvalidMethod| Error::Http(format!("invalid method: {e}")))?;
let mut builder = http::Request::builder().method(method).uri(uri);
let has_user_agent = custom_headers.iter().any(|(k, _)| k.eq_ignore_ascii_case("user-agent"));
if !has_user_agent {
builder = builder.header("user-agent", "curl/0.1.0");
}
let has_accept = custom_headers.iter().any(|(k, _)| k.eq_ignore_ascii_case("accept"));
if !has_accept {
builder = builder.header("accept", "*/*");
}
for (name, value) in custom_headers {
builder = builder.header(name.as_str(), value.as_str());
}
let is_head = builder.method_ref().is_some_and(|m| m == http::Method::HEAD);
let has_body = body.is_some();
let req =
builder.body(()).map_err(|e| Error::Http(format!("failed to build h2 request: {e}")))?;
let mut client =
client.ready().await.map_err(|e| Error::Http(format!("h2 connection not ready: {e}")))?;
let (mut response_fut, mut send_stream): (
h2::client::ResponseFuture,
h2::SendStream<bytes::Bytes>,
) = client
.send_request(req, !has_body)
.map_err(|e| Error::Http(format!("h2 send failed: {e}")))?;
let mut push_stream = response_fut.push_promises();
let push_task = tokio::spawn(async move {
let mut pushed = Vec::new();
while let Some(result) = push_stream.push_promise().await {
match result {
Ok(push_promise) => {
let (req, resp_future) = push_promise.into_parts();
let push_url = req.uri().to_string();
if let Ok(h2_resp) = resp_future.await {
let status = h2_resp.status().as_u16();
let mut headers = HashMap::new();
for (name, value) in h2_resp.headers() {
let name = name.as_str().to_lowercase();
let value = String::from_utf8_lossy(value.as_bytes()).to_string();
let _old = headers.insert(name, value);
}
let mut body_stream = h2_resp.into_body();
let mut body_bytes = Vec::new();
while let Some(chunk) = body_stream.data().await {
match chunk {
Ok(data) => {
let chunk_len = data.len();
body_bytes.extend_from_slice(&data);
let _r = body_stream.flow_control().release_capacity(chunk_len);
}
Err(_) => break,
}
}
pushed.push(PushedResponse {
url: push_url,
status,
headers,
body: body_bytes,
});
}
}
Err(_) => break, }
}
pushed
});
if let Some(body_data) = body {
let mut send_limiter = RateLimiter::for_send(speed_limits);
if send_limiter.is_active() {
let mut offset = 0;
while offset < body_data.len() {
let end = (offset + THROTTLE_CHUNK_SIZE).min(body_data.len());
let is_last = end == body_data.len();
let chunk = body_data[offset..end].to_vec();
let chunk_len = chunk.len();
send_stream
.send_data(chunk.into(), is_last)
.map_err(|e| Error::Http(format!("h2 body send failed: {e}")))?;
send_limiter.record(chunk_len).await?;
offset = end;
}
} else {
send_stream
.send_data(body_data.to_vec().into(), true)
.map_err(|e| Error::Http(format!("h2 body send failed: {e}")))?;
}
}
let h2_response =
response_fut.await.map_err(|e| Error::Http(format!("h2 response error: {e}")))?;
let status = h2_response.status().as_u16();
let mut headers = HashMap::new();
let mut original_names = HashMap::with_capacity(h2_response.headers().len());
for (name, value) in h2_response.headers() {
let lower = name.as_str().to_lowercase();
let value = String::from_utf8_lossy(value.as_bytes()).to_string();
let _old = original_names.entry(lower.clone()).or_insert_with(|| name.as_str().to_string());
let _old = headers.insert(lower, value);
}
if is_head {
let pushed = tokio::time::timeout(Duration::from_millis(50), push_task)
.await
.map_or_else(|_| Vec::new(), std::result::Result::unwrap_or_default);
let mut resp = Response::new(status, headers, Vec::new(), url.to_string());
resp.set_header_original_names(original_names);
resp.set_http_version(ResponseHttpVersion::Http2);
if !pushed.is_empty() {
resp.set_pushed_responses(pushed);
}
return Ok((resp, client));
}
let mut recv_limiter = RateLimiter::for_recv(speed_limits);
let mut body_stream = h2_response.into_body();
let mut body_bytes = Vec::new();
while let Some(chunk) = body_stream.data().await {
let chunk = chunk.map_err(|e| Error::Http(format!("h2 body read error: {e}")))?;
let chunk_len = chunk.len();
body_bytes.extend_from_slice(&chunk);
let _r = body_stream.flow_control().release_capacity(chunk_len);
if recv_limiter.is_active() {
recv_limiter.record(chunk_len).await?;
}
}
let pushed = tokio::time::timeout(Duration::from_millis(50), push_task)
.await
.map_or_else(|_| Vec::new(), std::result::Result::unwrap_or_default);
let mut resp = Response::new(status, headers, body_bytes, url.to_string());
resp.set_header_original_names(original_names);
resp.set_http_version(ResponseHttpVersion::Http2);
if !pushed.is_empty() {
resp.set_pushed_responses(pushed);
}
Ok((resp, client))
}
/// Opens an HTTP/2 connection over `stream`, issues a single request on it and returns
/// the response. The `SendRequest` handle is dropped once the response has been read.
///
/// # Errors
///
/// Propagates any error from `handshake` or `send_request`.
#[allow(clippy::too_many_arguments)]
pub async fn request<S>(
    stream: S,
    method: &str,
    host: &str,
    request_target: &str,
    custom_headers: &[(String, String)],
    body: Option<&[u8]>,
    url: &str,
    speed_limits: &SpeedLimits,
    h2_config: &Http2Config,
) -> Result<Response, Error>
where
    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
    let connection = handshake(stream, h2_config).await?;
    send_request(connection, method, host, request_target, custom_headers, body, url, speed_limits)
        .await
        .map(|(response, _connection)| response)
}
/// Sends an opaque HTTP/2 PING every `interval`, starting one interval after the call.
/// Stops as soon as a PING round-trip returns an error.
async fn ping_keepalive(mut ping_pong: h2::PingPong, interval: Duration) {
    let mut alive = true;
    while alive {
        tokio::time::sleep(interval).await;
        alive = ping_pong.ping(h2::Ping::opaque()).await.is_ok();
    }
}
#[cfg(test)]
// Tests may unwrap freely; a panic here is a test failure, not a runtime hazard.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn pushed_response_struct() {
        let pushed = PushedResponse {
            url: String::from("/style.css"),
            status: 200,
            headers: HashMap::new(),
            body: Vec::new(),
        };
        assert_eq!(pushed.url, "/style.css");
        assert_eq!(pushed.status, 200);
    }

    #[test]
    fn http2_config_default() {
        // Destructuring without `..` makes this test notice any newly added field.
        let Http2Config {
            window_size,
            connection_window_size,
            max_frame_size,
            max_header_list_size,
            enable_push,
            stream_weight,
            ping_interval,
        } = Http2Config::default();
        assert!(window_size.is_none());
        assert!(connection_window_size.is_none());
        assert!(max_frame_size.is_none());
        assert!(max_header_list_size.is_none());
        assert!(enable_push.is_none());
        assert!(stream_weight.is_none());
        assert!(ping_interval.is_none());
    }

    #[test]
    fn http2_config_custom() {
        let thirty_seconds = Duration::from_secs(30);
        let custom = Http2Config {
            window_size: Some(1_048_576),
            connection_window_size: Some(2_097_152),
            max_frame_size: Some(32_768),
            max_header_list_size: Some(8192),
            enable_push: Some(false),
            stream_weight: Some(128),
            ping_interval: Some(thirty_seconds),
        };
        assert_eq!(custom.window_size, Some(1_048_576));
        assert_eq!(custom.connection_window_size, Some(2_097_152));
        assert_eq!(custom.max_frame_size, Some(32_768));
        assert_eq!(custom.max_header_list_size, Some(8192));
        assert_eq!(custom.enable_push, Some(false));
        assert_eq!(custom.stream_weight, Some(128));
        assert_eq!(custom.ping_interval, Some(thirty_seconds));
    }

    #[test]
    fn http2_config_clone() {
        let source = Http2Config {
            window_size: Some(65_535),
            enable_push: Some(true),
            ..Http2Config::default()
        };
        #[allow(clippy::redundant_clone)]
        let copy = source.clone();
        assert_eq!(copy.window_size, Some(65_535));
        assert_eq!(copy.enable_push, Some(true));
        assert!(copy.max_frame_size.is_none());
    }

    #[test]
    fn http2_config_debug() {
        let rendered = format!("{:?}", Http2Config::default());
        assert!(rendered.contains("Http2Config"));
    }
}