use crate::config::Config;
use crate::dns::DnsResolver;
use crate::error::Error;
use crate::headers::{HeaderName, Headers};
use crate::method::Method;
use crate::parser::RequestBuilder as ParserRequestBuilder;
use crate::parser::uri::Uri;
use crate::socket::BlockingSocket;
use crate::transport::{ConnectionPool, Connector, PoolKey, RawResponse, ResponseBodyExpectation};
use alloc::string::String;
use alloc::sync::Arc;
use alloc::vec::Vec;
/// Executes HTTP requests using a borrowed connection pool, DNS resolver and configuration.
///
/// The executor owns none of its collaborators; all three are borrowed for `'a`.
/// `S` is the socket type and `D` the DNS resolver type (bounded by `BlockingSocket`
/// and `DnsResolver` on the `impl` block).
pub struct RequestExecutor<'a, S, D> {
/// Pool from which sockets are taken and to which reusable sockets are returned.
pool: &'a Arc<ConnectionPool<S>>,
/// Resolver handed to the `Connector` when establishing a connection.
dns: &'a D,
/// Client configuration (pooling switch, default `User-Agent` / `Accept` values).
config: &'a Config,
}
impl<'a, S, D> RequestExecutor<'a, S, D>
where
S: BlockingSocket,
D: DnsResolver,
{
/// Creates an executor that borrows the given pool, DNS resolver and configuration.
///
/// Nothing is connected or resolved here; the references are only stored.
/// Marked `#[must_use]` because constructing an executor and discarding it does nothing.
#[must_use]
pub const fn new(
    pool: &'a Arc<ConnectionPool<S>>,
    dns: &'a D,
    config: &'a Config,
) -> Self {
    Self { pool, dns, config }
}
/// Sends one request to the origin named by `uri` and returns the raw response.
///
/// The request bytes are assembled *before* a socket is taken from the pool or
/// created. A request that cannot be serialised therefore fails without opening a
/// connection and without discarding a pooled one.
///
/// A `HEAD` request is read with [`ResponseBodyExpectation::NoBody`]; every other
/// method uses [`ResponseBodyExpectation::Normal`]. After the response has been read,
/// the socket is offered back to the pool via `handle_connection_reuse`.
///
/// # Errors
///
/// Returns [`Error::IpAddressNotSupported`] when the URI host is an IP address,
/// [`Error::Parse`] when the request cannot be built, [`Error::Socket`] when a new
/// socket cannot be created, and propagates any error raised while connecting,
/// sending the request, or reading the response.
pub fn execute(
    &self,
    uri: &Uri,
    method: Method,
    custom_headers: &Headers,
    body: Option<&[u8]>,
) -> Result<RawResponse, Error> {
    let host_str = Self::extract_host_from_uri(uri)?;
    let port = Self::extract_port_from_uri(uri);

    // Serialise first: a build failure must not cost a pooled connection or a
    // fresh connect.
    let request_bytes = self.build_request(uri, method, &host_str, port, custom_headers, body)?;

    // The host string is no longer needed, so the pool key takes ownership of it
    // instead of a clone.
    let pool_key = PoolKey::new(host_str, port);
    let mut socket = self.get_or_create_socket(&pool_key)?;
    let connector = Connector::new(&mut socket, self.dns);
    let mut conn = connector.connect(uri, self.config)?;
    conn.send_request(&request_bytes)?;

    let expectation = if method == Method::Head {
        ResponseBodyExpectation::NoBody
    } else {
        ResponseBodyExpectation::Normal
    };
    let raw = conn.read_raw_response(expectation)?;

    // `is_reusable()` is evaluated before `socket` is moved, ending the borrow
    // held by `conn`.
    self.handle_connection_reuse(conn.is_reusable(), pool_key, socket);
    Ok(raw)
}
/// Returns the registered host name of `uri` as an owned `String`.
///
/// A URI without an authority component yields an empty string.
///
/// # Errors
///
/// Returns [`Error::IpAddressNotSupported`] when the host is a `Host::IpAddr`.
fn extract_host_from_uri(uri: &Uri) -> Result<String, Error> {
    let Some(authority) = uri.authority() else {
        return Ok(String::new());
    };
    match authority.host() {
        crate::parser::uri::Host::RegName(reg_name) => Ok(String::from(*reg_name)),
        crate::parser::uri::Host::IpAddr(_) => Err(Error::IpAddressNotSupported),
    }
}
fn extract_port_from_uri(uri: &Uri) -> u16 {
uri
.authority()
.and_then(super::super::parser::uri::Authority::port)
.unwrap_or_else(|| {
if uri.scheme() == "https" {
443
} else {
80
}
})
}
fn get_or_create_socket(
&self,
pool_key: &PoolKey,
) -> Result<S, Error> {
if self.config.connection_pooling {
self
.pool
.get(pool_key)
.map_or_else(|| S::new().map_err(Error::Socket), |s| Ok(s))
} else {
S::new().map_err(Error::Socket)
}
}
/// Serialises a request for `uri` into the bytes to be written to the socket.
///
/// Headers are added in this order:
///
/// 1. `Host`: `host_str`, with `:port` appended unless `port` is the default for the
///    scheme (80 for `http`, 443 for `https`).
/// 2. `Connection: close`, only when connection pooling is disabled.
/// 3. `User-Agent` from the configuration, unless `custom_headers` already has one.
/// 4. `Accept` from the configuration, unless `custom_headers` already has one.
/// 5. `Accept-Encoding` listing the compiled-in decompression features, unless
///    `custom_headers` already has one or no such feature is enabled.
/// 6. Every entry of `custom_headers`.
///
/// `body`, when present, is copied into the request.
///
/// # Errors
///
/// Returns [`Error::Parse`] when the underlying request builder rejects the request.
fn build_request(
    &self,
    uri: &Uri,
    method: Method,
    host_str: &str,
    port: u16,
    custom_headers: &Headers,
    body: Option<&[u8]>,
) -> Result<Vec<u8>, Error> {
    use alloc::format;

    // The port is left out of `Host` only when it is the scheme's default.
    let is_default_port =
        (uri.scheme() == "http" && port == 80) || (uri.scheme() == "https" && port == 443);
    let host_header = if is_default_port {
        String::from(host_str)
    } else {
        format!("{host_str}:{port}")
    };

    let mut builder = ParserRequestBuilder::new(method.as_str(), &uri.path_and_query())
        .header(HeaderName::HOST, host_header.as_str());

    if !self.config.connection_pooling {
        builder = builder.header(HeaderName::CONNECTION, "close");
    }

    // A caller-supplied `User-Agent` takes precedence over the configured default,
    // matching the `Accept` / `Accept-Encoding` handling below and avoiding a
    // duplicated header.
    if let Some(ref user_agent) = self.config.user_agent
        && !custom_headers.contains(HeaderName::USER_AGENT)
    {
        builder = builder.header(HeaderName::USER_AGENT, user_agent.as_str());
    }

    if let Some(ref accept) = self.config.accept
        && !custom_headers.contains(HeaderName::ACCEPT)
    {
        builder = builder.header(HeaderName::ACCEPT, accept.as_str());
    }

    if !custom_headers.contains(HeaderName::ACCEPT_ENCODING) {
        // With neither decompression feature enabled nothing is pushed, so `mut`
        // would be flagged as unused in that configuration.
        #[allow(unused_mut)]
        let mut encodings: Vec<&str> = Vec::new();
        #[cfg(feature = "gzip-decompression")]
        {
            encodings.push("gzip");
            encodings.push("deflate");
        }
        #[cfg(feature = "zstd-decompression")]
        encodings.push("zstd");

        // Only advertise encodings when at least one decoder is compiled in.
        if !encodings.is_empty() {
            let accept_encoding = encodings.join(", ");
            builder = builder.header(HeaderName::ACCEPT_ENCODING, accept_encoding.as_str());
        }
    }

    for (name, value) in custom_headers {
        builder = builder.header(name.as_str(), value.as_str());
    }

    if let Some(body_data) = body {
        builder = builder.body(body_data.to_vec());
    }

    builder.build().map_err(Error::Parse)
}
/// Hands `socket` back to the pool under `pool_key` when it may be reused.
///
/// The socket is returned only if connection pooling is enabled in the configuration
/// and `is_reusable` is `true`; otherwise it is dropped when this function returns.
fn handle_connection_reuse(&self, is_reusable: bool, pool_key: PoolKey, socket: S) {
    let keep_alive = self.config.connection_pooling && is_reusable;
    if !keep_alive {
        return;
    }
    self.pool.return_connection(pool_key, socket);
}
}