#[cfg(test)]
mod tests;
use std::io::{Read, Write};
use std::net::TcpStream;
use std::time::Duration;
#[cfg(any(feature = "http-client", feature = "http2"))]
use std::sync::Arc;
/// Error type returned by the fallible operations in this file.
///
/// The single public field holds the complete, human-readable message.
#[derive(Debug)]
pub struct HttpClientError(pub String);

impl std::fmt::Display for HttpClientError {
    /// Writes the wrapped message verbatim.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let HttpClientError(message) = self;
        f.write_str(message.as_str())
    }
}

impl std::error::Error for HttpClientError {}
/// Components of an absolute `http://` or `https://` URL.
struct ParsedUrl {
    /// Either `"http"` or `"https"`.
    scheme: String,
    /// Host as written in the URL; IPv6 literals keep their square brackets.
    host: String,
    /// Explicit port from the URL, or 80 / 443 depending on the scheme.
    port: u16,
    /// Request target (path plus optional query); always begins with `/`.
    path_and_query: String,
}

impl ParsedUrl {
    /// Splits `url` into scheme, host, port and request target.
    ///
    /// A fragment (`#...`) is dropped because it is never part of the request
    /// target, and a URL that has a query but no path (`http://host?x=1`)
    /// yields the target `/?x=1`.
    ///
    /// # Errors
    /// Returns an error when the scheme is not `http://` / `https://`, when
    /// the port is not a valid `u16`, when anything other than `:port`
    /// follows a bracketed IPv6 literal, or when the host is empty.
    fn parse(url: &str) -> Result<Self, HttpClientError> {
        let (scheme, rest) = if let Some(r) = url.strip_prefix("https://") {
            ("https", r)
        } else if let Some(r) = url.strip_prefix("http://") {
            ("http", r)
        } else {
            return Err(HttpClientError(format!(
                "unsupported or missing URL scheme in '{url}'"
            )));
        };
        let default_port: u16 = if scheme == "https" { 443 } else { 80 };

        // The fragment is client-side only; cut it before looking at anything else.
        let rest = rest.split('#').next().unwrap_or(rest);

        // The authority ends at the first '/' or '?', whichever comes first.
        let (authority, path_and_query) = match rest.find(|c: char| c == '/' || c == '?') {
            Some(idx) => {
                let (authority, target) = rest.split_at(idx);
                if target.starts_with('?') {
                    // Query without a path: the path is implicitly "/".
                    (authority, format!("/{target}"))
                } else {
                    (authority, target.to_string())
                }
            }
            None => (rest, "/".to_string()),
        };

        let invalid_port = || HttpClientError(format!("invalid port in URL '{url}'"));
        let (host, port) = if let Some(bracket_end) = authority.find(']') {
            // Bracketed IPv6 literal: the colons inside the brackets are not
            // port separators, so only look at what follows the ']'.
            let host = &authority[..=bracket_end];
            let after = &authority[bracket_end + 1..];
            let port = if after.is_empty() {
                default_port
            } else if let Some(p) = after.strip_prefix(':') {
                p.parse::<u16>().map_err(|_| invalid_port())?
            } else {
                return Err(invalid_port());
            };
            (host.to_string(), port)
        } else {
            match authority.rfind(':') {
                Some(idx) => {
                    let port = authority[idx + 1..]
                        .parse::<u16>()
                        .map_err(|_| invalid_port())?;
                    (authority[..idx].to_string(), port)
                }
                None => (authority.to_string(), default_port),
            }
        };

        if host.is_empty() {
            return Err(HttpClientError(format!("missing host in URL '{url}'")));
        }
        Ok(ParsedUrl {
            scheme: scheme.to_string(),
            host,
            port,
            path_and_query,
        })
    }
}
/// Resolves a redirect `location` against the URL that produced it.
///
/// * An absolute `http://` / `https://` location is returned unchanged.
/// * A protocol-relative location (`//host/path`) inherits the base scheme.
/// * A location starting with `/` replaces the whole path and query.
/// * A location starting with `?` keeps the base path and replaces the query.
/// * Anything else is appended to the directory of the base path; the base
///   query is ignored so a `/` inside it cannot be mistaken for a path
///   separator. Dot segments (`./`, `../`) are not collapsed.
///
/// If `base_url` itself cannot be parsed, `location` is returned as-is.
fn resolve_url(base_url: &str, location: &str) -> String {
    if location.starts_with("http://") || location.starts_with("https://") {
        return location.to_string();
    }
    let base = match ParsedUrl::parse(base_url) {
        Ok(parsed) => parsed,
        Err(_) => return location.to_string(),
    };

    // Must be checked before the single-'/' case below, which it also matches.
    if location.starts_with("//") {
        return format!("{}:{}", base.scheme, location);
    }

    let default_port = if base.scheme == "https" { 443 } else { 80 };
    let origin = if base.port == default_port {
        format!("{}://{}", base.scheme, base.host)
    } else {
        format!("{}://{}:{}", base.scheme, base.host, base.port)
    };

    if location.starts_with('/') {
        return format!("{origin}{location}");
    }

    // Only the path takes part in relative resolution, never the query.
    let base_path = base.path_and_query.split('?').next().unwrap_or("/");
    if location.starts_with('?') {
        return format!("{origin}{base_path}{location}");
    }
    let dir = match base_path.rfind('/') {
        Some(i) => &base_path[..=i],
        None => "/",
    };
    format!("{origin}{dir}{location}")
}
/// A fully buffered HTTP response: status code, headers in the order they
/// were received, and the raw body bytes.
#[derive(Debug)]
pub struct Response {
    status: u16,
    headers: Vec<(String, String)>,
    body: Vec<u8>,
}

impl Response {
    /// Numeric status code from the status line.
    pub fn status(&self) -> u16 {
        self.status
    }

    /// `true` for any 2xx status.
    pub fn is_success(&self) -> bool {
        (200..300).contains(&self.status)
    }

    /// `true` for 301, 302, 303, 307 and 308.
    pub fn is_redirect(&self) -> bool {
        matches!(self.status, 301 | 302 | 303 | 307 | 308)
    }

    /// Returns the value of the first header whose name matches `name`,
    /// ignoring ASCII case.
    ///
    /// The comparison is done in place, so no temporary strings are
    /// allocated per header.
    pub fn header(&self, name: &str) -> Option<&str> {
        self.headers
            .iter()
            .find(|(k, _)| k.eq_ignore_ascii_case(name))
            .map(|(_, v)| v.as_str())
    }

    /// Raw body bytes.
    pub fn bytes(&self) -> &[u8] {
        &self.body
    }

    /// Body decoded as UTF-8.
    ///
    /// # Errors
    /// Returns an error if the body is not valid UTF-8.
    pub fn text(&self) -> Result<String, HttpClientError> {
        // Validate on the borrowed bytes first so nothing is copied on failure.
        std::str::from_utf8(&self.body)
            .map(str::to_owned)
            .map_err(|e| HttpClientError(format!("body is not valid UTF-8: {e}")))
    }

    /// Body deserialized from JSON into `T`.
    ///
    /// # Errors
    /// Returns an error if `serde_json` cannot deserialize the body as `T`.
    #[cfg(feature = "serde")]
    pub fn json<T: serde::de::DeserializeOwned>(&self) -> Result<T, HttpClientError> {
        serde_json::from_slice(&self.body)
            .map_err(|e| HttpClientError(format!("JSON parse error: {e}")))
    }
}
fn build_request_bytes(
method: &str,
path_and_query: &str,
host: &str,
headers: &[(String, String)],
body: &Option<Vec<u8>>,
) -> Vec<u8> {
let mut out: Vec<u8> = Vec::new();
let _ = write!(
out,
"{method} {path_and_query} HTTP/1.1\r\nHost: {host}\r\nConnection: close\r\nUser-Agent: rust-web-server/{}\r\n",
env!("CARGO_PKG_VERSION"),
);
if let Some(b) = body {
if !b.is_empty() {
let _ = write!(out, "Content-Length: {}\r\n", b.len());
}
}
for (k, v) in headers {
let _ = write!(out, "{k}: {v}\r\n");
}
out.extend_from_slice(b"\r\n");
if let Some(b) = body {
out.extend_from_slice(b);
}
out
}
fn read_response(stream: &mut dyn Read, is_head: bool) -> Result<Response, HttpClientError> {
let mut buf: Vec<u8> = Vec::with_capacity(8192);
let mut tmp = [0u8; 4096];
let header_end = loop {
let n = stream
.read(&mut tmp)
.map_err(|e| HttpClientError(format!("read error: {e}")))?;
if n == 0 {
if buf.is_empty() {
return Err(HttpClientError(
"server closed connection without sending a response".into(),
));
}
break buf.len();
}
buf.extend_from_slice(&tmp[..n]);
if let Some(pos) = buf.windows(4).position(|w| w == b"\r\n\r\n") {
break pos + 4;
}
};
let header_block = std::str::from_utf8(&buf[..header_end])
.map_err(|_| HttpClientError("response headers are not valid UTF-8".into()))?;
let mut lines = header_block.lines();
let status_line = lines
.next()
.ok_or_else(|| HttpClientError("empty response".into()))?;
let status = parse_status(status_line)?;
let response_headers: Vec<(String, String)> = lines
.filter_map(|line| {
let mut parts = line.splitn(2, ':');
let name = parts.next()?.trim().to_string();
let value = parts.next()?.trim().to_string();
if name.is_empty() {
None
} else {
Some((name, value))
}
})
.collect();
let mut body = buf[header_end..].to_vec();
if !is_head {
let transfer_encoding = response_headers
.iter()
.find(|(k, _)| k.to_lowercase() == "transfer-encoding")
.map(|(_, v)| v.to_lowercase());
let content_length: Option<usize> = response_headers
.iter()
.find(|(k, _)| k.to_lowercase() == "content-length")
.and_then(|(_, v)| v.trim().parse().ok());
if transfer_encoding
.as_deref()
.map(|te| te.contains("chunked"))
.unwrap_or(false)
{
loop {
let n = stream
.read(&mut tmp)
.map_err(|e| HttpClientError(format!("read error: {e}")))?;
if n == 0 {
break;
}
body.extend_from_slice(&tmp[..n]);
}
body = decode_chunked(&body)?;
} else if let Some(len) = content_length {
while body.len() < len {
let n = stream
.read(&mut tmp)
.map_err(|e| HttpClientError(format!("read error: {e}")))?;
if n == 0 {
break;
}
body.extend_from_slice(&tmp[..n]);
}
body.truncate(len);
} else {
loop {
let n = stream
.read(&mut tmp)
.map_err(|e| HttpClientError(format!("read error: {e}")))?;
if n == 0 {
break;
}
body.extend_from_slice(&tmp[..n]);
}
}
} else {
body.clear();
}
Ok(Response {
status,
headers: response_headers,
body,
})
}
/// Extracts the numeric status code from an HTTP status line such as
/// `HTTP/1.1 200 OK`.
///
/// The line is split on single spaces; the second field must parse as `u16`.
///
/// # Errors
/// Returns an error when the code field is absent or is not a valid `u16`.
fn parse_status(line: &str) -> Result<u16, HttpClientError> {
    let mut fields = line.splitn(3, ' ');
    // First field is the protocol version; its content is not inspected.
    if fields.next().is_none() {
        return Err(HttpClientError("malformed status line".into()));
    }
    match fields.next() {
        None => Err(HttpClientError("missing status code".into())),
        Some(code_str) => match code_str.parse::<u16>() {
            Ok(code) => Ok(code),
            Err(_) => Err(HttpClientError(format!("invalid status code '{code_str}'"))),
        },
    }
}
/// Decodes a complete `Transfer-Encoding: chunked` body.
///
/// Each chunk is `<hex size>[;extensions]\r\n<data>\r\n`; decoding stops at
/// the zero-size chunk or when the input is exhausted. Chunk extensions and
/// anything after the terminating chunk are ignored.
///
/// # Errors
/// Returns an error when a size line has no CRLF, is not UTF-8, is not valid
/// hexadecimal, or announces more data than `data` contains.
fn decode_chunked(data: &[u8]) -> Result<Vec<u8>, HttpClientError> {
    let mut out = Vec::new();
    let mut pos = 0;
    while pos < data.len() {
        let line_end = data[pos..]
            .windows(2)
            .position(|w| w == b"\r\n")
            .ok_or_else(|| HttpClientError("invalid chunked encoding: missing CRLF".into()))?;
        let size_line = std::str::from_utf8(&data[pos..pos + line_end])
            .map_err(|_| HttpClientError("chunked size is not ASCII".into()))?
            .trim();
        // Anything after ';' is a chunk extension and does not affect the size.
        let size_str = size_line.split(';').next().unwrap_or("").trim();
        let chunk_size = usize::from_str_radix(size_str, 16)
            .map_err(|_| HttpClientError(format!("invalid chunk size '{size_str}'")))?;
        pos += line_end + 2;
        if chunk_size == 0 {
            break;
        }
        // The size comes from the peer: a huge value must be reported as a
        // truncated body instead of overflowing the addition.
        let end = pos
            .checked_add(chunk_size)
            .filter(|&end| end <= data.len())
            .ok_or_else(|| HttpClientError("chunked body truncated".into()))?;
        out.extend_from_slice(&data[pos..end]);
        // Skip the CRLF that follows the chunk data.
        pos = end + 2;
    }
    Ok(out)
}
/// Wraps an established TCP connection in a rustls client stream for `host`.
///
/// Trust anchors are taken from `webpki_roots::TLS_SERVER_ROOTS` and no
/// client certificate is offered. The configuration is rebuilt on every call.
///
/// `host` may be a bracketed IPv6 literal (the form `ParsedUrl` stores); the
/// brackets are removed before the server name is parsed, because an IP
/// address is only recognized without them.
///
/// # Errors
/// Returns an error when `host` is not a valid server name or when the
/// client connection cannot be created.
#[cfg(any(feature = "http-client", feature = "http2"))]
fn tls_connect(
    host: &str,
    tcp: TcpStream,
) -> Result<rustls::StreamOwned<rustls::ClientConnection, TcpStream>, HttpClientError> {
    use rustls::pki_types::ServerName;
    use rustls::ClientConfig;

    let root_store =
        rustls::RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
    let config = Arc::new(
        ClientConfig::builder()
            .with_root_certificates(root_store)
            .with_no_client_auth(),
    );

    // "[::1]" -> "::1"; any other host is passed through untouched.
    let name_source = host
        .strip_prefix('[')
        .and_then(|inner| inner.strip_suffix(']'))
        .unwrap_or(host);
    let server_name = ServerName::try_from(name_source.to_string())
        .map_err(|e| HttpClientError(format!("invalid hostname '{host}': {e}")))?;

    let conn = rustls::ClientConnection::new(config, server_name)
        .map_err(|e| HttpClientError(e.to_string()))?;
    Ok(rustls::StreamOwned::new(conn, tcp))
}
/// Performs a single request/response exchange (no redirect handling).
///
/// Resolves `parsed.host:parsed.port`, tries each resolved address in turn
/// until one accepts the connection, sends the request and reads the
/// response. `timeout_ms` is applied to each connect attempt and as the
/// socket read and write timeout.
///
/// The `Host` header carries the port whenever it differs from the scheme
/// default (80 for `http`, 443 for `https`).
///
/// # Errors
/// Fails when name resolution yields an error or no address, when no address
/// accepts the connection, on socket configuration / write / read errors,
/// when the response cannot be parsed, and when an `https` URL is requested
/// but neither the `http-client` nor the `http2` feature is enabled (sending
/// such a request in plaintext would silently drop the transport security
/// the caller asked for).
fn send_once(
    method: &str,
    parsed: &ParsedUrl,
    headers: &[(String, String)],
    body: &Option<Vec<u8>>,
    timeout_ms: u64,
) -> Result<Response, HttpClientError> {
    use std::net::ToSocketAddrs;

    #[cfg(not(any(feature = "http-client", feature = "http2")))]
    if parsed.scheme == "https" {
        return Err(HttpClientError(
            "https URLs require the 'http-client' or 'http2' feature".into(),
        ));
    }

    let addr = format!("{}:{}", parsed.host, parsed.port);
    let timeout = Duration::from_millis(timeout_ms);

    // `to_socket_addrs` parses literal socket addresses directly and only
    // falls back to a lookup for host names.
    let candidates = addr
        .to_socket_addrs()
        .map_err(|e| HttpClientError(format!("DNS lookup for '{addr}' failed: {e}")))?;

    // A name may resolve to several addresses (for example IPv6 and IPv4);
    // keep going until one of them accepts the connection.
    let mut connected = None;
    let mut last_err = None;
    for candidate in candidates {
        match TcpStream::connect_timeout(&candidate, timeout) {
            Ok(stream) => {
                connected = Some(stream);
                break;
            }
            Err(e) => last_err = Some(e),
        }
    }
    let tcp = match (connected, last_err) {
        (Some(stream), _) => stream,
        (None, Some(e)) => {
            return Err(HttpClientError(format!("connect to '{addr}' failed: {e}")));
        }
        (None, None) => return Err(HttpClientError(format!("no address for '{addr}'"))),
    };

    tcp.set_read_timeout(Some(timeout))
        .map_err(|e| HttpClientError(e.to_string()))?;
    tcp.set_write_timeout(Some(timeout))
        .map_err(|e| HttpClientError(e.to_string()))?;

    let default_port: u16 = if parsed.scheme == "https" { 443 } else { 80 };
    let host_header = if parsed.port == default_port {
        parsed.host.as_str()
    } else {
        addr.as_str()
    };
    let request_bytes =
        build_request_bytes(method, &parsed.path_and_query, host_header, headers, body);
    let is_head = method.eq_ignore_ascii_case("HEAD");

    #[cfg(any(feature = "http-client", feature = "http2"))]
    if parsed.scheme == "https" {
        let mut tls_stream = tls_connect(&parsed.host, tcp)?;
        tls_stream
            .write_all(&request_bytes)
            .map_err(|e| HttpClientError(format!("write error: {e}")))?;
        return read_response(&mut tls_stream, is_head);
    }

    let mut stream = tcp;
    stream
        .write_all(&request_bytes)
        .map_err(|e| HttpClientError(format!("write error: {e}")))?;
    read_response(&mut stream, is_head)
}
/// Blocking HTTP client holding the settings shared by the requests it
/// creates.
pub struct Client {
    /// Default timeout in milliseconds for requests that do not set their own.
    timeout_ms: u64,
    /// Maximum number of redirects followed per request.
    max_redirects: u8,
}

impl Client {
    /// Creates a client with a 30 000 ms timeout that follows up to 10
    /// redirects.
    pub fn new() -> Self {
        Client {
            timeout_ms: 30_000,
            max_redirects: 10,
        }
    }

    /// Returns the client with its default timeout replaced by `ms`.
    pub fn timeout_ms(self, ms: u64) -> Self {
        Client {
            timeout_ms: ms,
            ..self
        }
    }

    /// Returns the client with its redirect limit replaced by `n`.
    pub fn max_redirects(self, n: u8) -> Self {
        Client {
            max_redirects: n,
            ..self
        }
    }

    /// Starts a `GET` request to `url`.
    pub fn get(&self, url: &str) -> RequestBuilder<'_> {
        self.request("GET", url)
    }

    /// Starts a `POST` request to `url`.
    pub fn post(&self, url: &str) -> RequestBuilder<'_> {
        self.request("POST", url)
    }

    /// Starts a `PUT` request to `url`.
    pub fn put(&self, url: &str) -> RequestBuilder<'_> {
        self.request("PUT", url)
    }

    /// Starts a `PATCH` request to `url`.
    pub fn patch(&self, url: &str) -> RequestBuilder<'_> {
        self.request("PATCH", url)
    }

    /// Starts a `DELETE` request to `url`.
    pub fn delete(&self, url: &str) -> RequestBuilder<'_> {
        self.request("DELETE", url)
    }

    /// Starts a `HEAD` request to `url`.
    pub fn head(&self, url: &str) -> RequestBuilder<'_> {
        self.request("HEAD", url)
    }

    /// Starts a request with an arbitrary `method`, which is upper-cased.
    ///
    /// The builder starts with no headers, no body and no per-request
    /// timeout.
    pub fn request(&self, method: &str, url: &str) -> RequestBuilder<'_> {
        let method = method.to_uppercase();
        let url = url.to_owned();
        RequestBuilder {
            client: self,
            method,
            url,
            headers: Vec::new(),
            body: None,
            timeout_ms: None,
        }
    }
}

impl Default for Client {
    /// Same as [`Client::new`].
    fn default() -> Self {
        Client::new()
    }
}
/// A request under construction, created by one of the [`Client`] methods.
pub struct RequestBuilder<'a> {
    client: &'a Client,
    method: String,
    url: String,
    headers: Vec<(String, String)>,
    body: Option<Vec<u8>>,
    /// Per-request timeout; falls back to the client's when `None`.
    timeout_ms: Option<u64>,
}

impl<'a> RequestBuilder<'a> {
    /// Appends a header. Headers are sent in the order they were added.
    pub fn header(mut self, name: &str, value: &str) -> Self {
        self.headers.push((name.to_string(), value.to_string()));
        self
    }

    /// Sets the request body to `bytes` without touching the headers.
    pub fn body(mut self, bytes: Vec<u8>) -> Self {
        self.body = Some(bytes);
        self
    }

    /// Sets the body to `s` and appends `Content-Type: text/plain`.
    pub fn body_text(mut self, s: &str) -> Self {
        self.headers
            .push(("Content-Type".to_string(), "text/plain".to_string()));
        self.body = Some(s.as_bytes().to_vec());
        self
    }

    /// Sets the body to `s` and appends `Content-Type: application/json`.
    /// The string is sent as given; it is not validated as JSON.
    pub fn body_json(mut self, s: &str) -> Self {
        self.headers.push((
            "Content-Type".to_string(),
            "application/json".to_string(),
        ));
        self.body = Some(s.as_bytes().to_vec());
        self
    }

    /// Overrides the client's timeout for this request.
    pub fn timeout_ms(mut self, ms: u64) -> Self {
        self.timeout_ms = Some(ms);
        self
    }

    /// Sends the request, following redirects up to the client's limit.
    ///
    /// On 301, 302 and 303 the body is dropped together with the headers
    /// that describe it, and the method becomes `GET` unless it already is
    /// `GET` or `HEAD` (a `HEAD` request must not turn into a body-carrying
    /// `GET`). On 307 and 308 method, headers and body are re-sent unchanged.
    /// Once the limit is reached the redirect response itself is returned.
    ///
    /// # Errors
    /// Fails when a URL cannot be parsed, when the exchange itself fails, or
    /// when a followed redirect has no `Location` header.
    pub fn send(self) -> Result<Response, HttpClientError> {
        // Headers that only make sense together with a request body.
        const BODY_HEADERS: [&str; 4] = [
            "content-type",
            "content-length",
            "content-encoding",
            "transfer-encoding",
        ];

        let timeout = self.timeout_ms.unwrap_or(self.client.timeout_ms);
        let max_redirects = self.client.max_redirects;
        let mut method = self.method;
        let mut url = self.url;
        let mut headers = self.headers;
        let mut body = self.body;
        let mut redirects = 0u8;
        loop {
            let parsed = ParsedUrl::parse(&url)?;
            let resp = send_once(&method, &parsed, &headers, &body, timeout)?;
            if !(resp.is_redirect() && redirects < max_redirects) {
                return Ok(resp);
            }
            let location = resp
                .header("location")
                .ok_or_else(|| HttpClientError("redirect with no Location header".into()))?
                .to_string();
            url = resolve_url(&url, &location);
            redirects += 1;
            if matches!(resp.status(), 301 | 302 | 303) {
                let keeps_method =
                    method.eq_ignore_ascii_case("GET") || method.eq_ignore_ascii_case("HEAD");
                if !keeps_method {
                    method = "GET".to_string();
                }
                if body.take().is_some() {
                    headers.retain(|(name, _)| {
                        !BODY_HEADERS.iter().any(|h| name.eq_ignore_ascii_case(h))
                    });
                }
            }
        }
    }
}
#[cfg(feature = "http2")]
pub use async_impl::{AsyncClient, AsyncRequestBuilder};
#[cfg(feature = "http2")]
mod async_impl {
use super::{
build_request_bytes, decode_chunked, parse_status, resolve_url, HttpClientError,
ParsedUrl, Response,
};
use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
/// Performs a TLS handshake for `host` over an established tokio TCP stream.
///
/// Trust anchors are taken from `webpki_roots::TLS_SERVER_ROOTS` and no
/// client certificate is offered. The configuration is rebuilt on every call.
///
/// `host` may be a bracketed IPv6 literal (the form `ParsedUrl` stores); the
/// brackets are removed before the server name is parsed, because an IP
/// address is only recognized without them.
///
/// # Errors
/// Returns an error when `host` is not a valid server name or when the
/// handshake fails.
async fn async_tls_connect(
    host: &str,
    stream: tokio::net::TcpStream,
) -> Result<tokio_rustls::client::TlsStream<tokio::net::TcpStream>, HttpClientError> {
    use rustls::pki_types::ServerName;
    use rustls::ClientConfig;
    use tokio_rustls::TlsConnector;

    let root_store = rustls::RootCertStore::from_iter(
        webpki_roots::TLS_SERVER_ROOTS.iter().cloned(),
    );
    let config = Arc::new(
        ClientConfig::builder()
            .with_root_certificates(root_store)
            .with_no_client_auth(),
    );
    let connector = TlsConnector::from(config);

    // "[::1]" -> "::1"; any other host is passed through untouched.
    let name_source = host
        .strip_prefix('[')
        .and_then(|inner| inner.strip_suffix(']'))
        .unwrap_or(host);
    let server_name = ServerName::try_from(name_source.to_string())
        .map_err(|e| HttpClientError(format!("invalid hostname '{host}': {e}")))?;

    connector
        .connect(server_name, stream)
        .await
        .map_err(|e| HttpClientError(format!("TLS handshake failed: {e}")))
}
async fn async_read_response(
stream: &mut (impl AsyncReadExt + Unpin),
is_head: bool,
) -> Result<Response, HttpClientError> {
let mut buf: Vec<u8> = Vec::with_capacity(8192);
let mut tmp = vec![0u8; 4096];
let header_end = loop {
let n = stream
.read(&mut tmp)
.await
.map_err(|e| HttpClientError(format!("read error: {e}")))?;
if n == 0 {
if buf.is_empty() {
return Err(HttpClientError(
"server closed connection without a response".into(),
));
}
break buf.len();
}
buf.extend_from_slice(&tmp[..n]);
if let Some(pos) = buf.windows(4).position(|w| w == b"\r\n\r\n") {
break pos + 4;
}
};
let header_block = std::str::from_utf8(&buf[..header_end])
.map_err(|_| HttpClientError("response headers not UTF-8".into()))?;
let mut lines = header_block.lines();
let status_line = lines
.next()
.ok_or_else(|| HttpClientError("empty response".into()))?;
let status = parse_status(status_line)?;
let response_headers: Vec<(String, String)> = lines
.filter_map(|line| {
let mut parts = line.splitn(2, ':');
let name = parts.next()?.trim().to_string();
let value = parts.next()?.trim().to_string();
if name.is_empty() { None } else { Some((name, value)) }
})
.collect();
let mut body = buf[header_end..].to_vec();
if !is_head {
let transfer_encoding = response_headers
.iter()
.find(|(k, _)| k.to_lowercase() == "transfer-encoding")
.map(|(_, v)| v.to_lowercase());
let content_length: Option<usize> = response_headers
.iter()
.find(|(k, _)| k.to_lowercase() == "content-length")
.and_then(|(_, v)| v.trim().parse().ok());
if transfer_encoding
.as_deref()
.map(|te| te.contains("chunked"))
.unwrap_or(false)
{
loop {
let n = stream.read(&mut tmp).await
.map_err(|e| HttpClientError(format!("read error: {e}")))?;
if n == 0 { break; }
body.extend_from_slice(&tmp[..n]);
}
body = decode_chunked(&body)?;
} else if let Some(len) = content_length {
while body.len() < len {
let n = stream.read(&mut tmp).await
.map_err(|e| HttpClientError(format!("read error: {e}")))?;
if n == 0 { break; }
body.extend_from_slice(&tmp[..n]);
}
body.truncate(len);
} else {
loop {
let n = stream.read(&mut tmp).await
.map_err(|e| HttpClientError(format!("read error: {e}")))?;
if n == 0 { break; }
body.extend_from_slice(&tmp[..n]);
}
}
} else {
body.clear();
}
Ok(Response { status, headers: response_headers, body })
}
/// Performs a single async request/response exchange (no redirect handling).
///
/// Connects to `parsed.host:parsed.port`, upgrades to TLS for `https`, writes
/// the request and reads the response. Each stage (connect, TLS handshake,
/// write, read) is individually limited to `timeout_ms`.
///
/// The `Host` header carries the port whenever it differs from the scheme
/// default (80 for `http`, 443 for `https`).
///
/// # Errors
/// Fails when a stage times out or reports an I/O error, when the TLS
/// handshake fails, or when the response cannot be parsed.
async fn async_send_once(
    method: &str,
    parsed: &ParsedUrl,
    headers: &[(String, String)],
    body: &Option<Vec<u8>>,
    timeout_ms: u64,
) -> Result<Response, HttpClientError> {
    use std::time::Duration;
    use tokio::net::TcpStream;
    use tokio::time::timeout;

    let addr = format!("{}:{}", parsed.host, parsed.port);
    let dur = Duration::from_millis(timeout_ms);

    let default_port: u16 = if parsed.scheme == "https" { 443 } else { 80 };
    let host_header = if parsed.port == default_port {
        parsed.host.as_str()
    } else {
        addr.as_str()
    };
    let request_bytes =
        build_request_bytes(method, &parsed.path_and_query, host_header, headers, body);
    let is_head = method.eq_ignore_ascii_case("HEAD");

    // Outer error: the timer fired; inner error: the connect itself failed.
    let tcp = timeout(dur, TcpStream::connect(&addr))
        .await
        .map_err(|_| HttpClientError(format!("connect to '{addr}' timed out")))?
        .map_err(|e| HttpClientError(format!("connect to '{addr}' failed: {e}")))?;

    if parsed.scheme == "https" {
        let tls_stream = timeout(dur, async_tls_connect(&parsed.host, tcp))
            .await
            .map_err(|_| HttpClientError("TLS handshake timed out".into()))??;
        let mut stream = tls_stream;
        timeout(dur, stream.write_all(&request_bytes))
            .await
            .map_err(|_| HttpClientError("write timed out".into()))?
            .map_err(|e| HttpClientError(format!("write error: {e}")))?;
        return timeout(dur, async_read_response(&mut stream, is_head))
            .await
            .map_err(|_| HttpClientError("read timed out".into()))?;
    }

    let mut stream = tcp;
    timeout(dur, stream.write_all(&request_bytes))
        .await
        .map_err(|_| HttpClientError("write timed out".into()))?
        .map_err(|e| HttpClientError(format!("write error: {e}")))?;
    timeout(dur, async_read_response(&mut stream, is_head))
        .await
        .map_err(|_| HttpClientError("read timed out".into()))?
}
/// Async HTTP client holding the settings shared by the requests it creates.
pub struct AsyncClient {
    /// Default timeout in milliseconds for requests that do not set their own.
    timeout_ms: u64,
    /// Maximum number of redirects followed per request.
    max_redirects: u8,
}

impl AsyncClient {
    /// Creates a client with a 30 000 ms timeout that follows up to 10
    /// redirects.
    pub fn new() -> Self {
        Self {
            timeout_ms: 30_000,
            max_redirects: 10,
        }
    }

    /// Returns the client with its default timeout replaced by `ms`.
    pub fn timeout_ms(mut self, ms: u64) -> Self {
        self.timeout_ms = ms;
        self
    }

    /// Returns the client with its redirect limit replaced by `n`.
    pub fn max_redirects(mut self, n: u8) -> Self {
        self.max_redirects = n;
        self
    }

    /// Starts a `GET` request to `url`.
    pub fn get(&self, url: &str) -> AsyncRequestBuilder<'_> {
        self.request("GET", url)
    }

    /// Starts a `POST` request to `url`.
    pub fn post(&self, url: &str) -> AsyncRequestBuilder<'_> {
        self.request("POST", url)
    }

    /// Starts a `PUT` request to `url`.
    pub fn put(&self, url: &str) -> AsyncRequestBuilder<'_> {
        self.request("PUT", url)
    }

    /// Starts a `PATCH` request to `url`.
    pub fn patch(&self, url: &str) -> AsyncRequestBuilder<'_> {
        self.request("PATCH", url)
    }

    /// Starts a `DELETE` request to `url`.
    pub fn delete(&self, url: &str) -> AsyncRequestBuilder<'_> {
        self.request("DELETE", url)
    }

    /// Starts a `HEAD` request to `url`, mirroring the blocking client's
    /// shortcut of the same name.
    pub fn head(&self, url: &str) -> AsyncRequestBuilder<'_> {
        self.request("HEAD", url)
    }

    /// Starts a request with an arbitrary `method`, which is upper-cased.
    ///
    /// The builder starts with no headers, no body and no per-request
    /// timeout.
    pub fn request(&self, method: &str, url: &str) -> AsyncRequestBuilder<'_> {
        AsyncRequestBuilder {
            client: self,
            method: method.to_uppercase(),
            url: url.to_string(),
            headers: Vec::new(),
            body: None,
            timeout_ms: None,
        }
    }
}

impl Default for AsyncClient {
    /// Same as [`AsyncClient::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// An async request under construction, created by one of the
/// [`AsyncClient`] methods.
pub struct AsyncRequestBuilder<'a> {
    client: &'a AsyncClient,
    method: String,
    url: String,
    headers: Vec<(String, String)>,
    body: Option<Vec<u8>>,
    /// Per-request timeout; falls back to the client's when `None`.
    timeout_ms: Option<u64>,
}

impl<'a> AsyncRequestBuilder<'a> {
    /// Appends a header. Headers are sent in the order they were added.
    pub fn header(mut self, name: &str, value: &str) -> Self {
        self.headers.push((name.to_string(), value.to_string()));
        self
    }

    /// Sets the request body to `bytes` without touching the headers.
    pub fn body(mut self, bytes: Vec<u8>) -> Self {
        self.body = Some(bytes);
        self
    }

    /// Sets the body to `s` and appends `Content-Type: text/plain`.
    pub fn body_text(mut self, s: &str) -> Self {
        self.headers
            .push(("Content-Type".to_string(), "text/plain".to_string()));
        self.body = Some(s.as_bytes().to_vec());
        self
    }

    /// Sets the body to `s` and appends `Content-Type: application/json`.
    /// The string is sent as given; it is not validated as JSON.
    pub fn body_json(mut self, s: &str) -> Self {
        self.headers.push((
            "Content-Type".to_string(),
            "application/json".to_string(),
        ));
        self.body = Some(s.as_bytes().to_vec());
        self
    }

    /// Overrides the client's timeout for this request.
    pub fn timeout_ms(mut self, ms: u64) -> Self {
        self.timeout_ms = Some(ms);
        self
    }

    /// Sends the request, following redirects up to the client's limit.
    ///
    /// On 301, 302 and 303 the body is dropped together with the headers
    /// that describe it, and the method becomes `GET` unless it already is
    /// `GET` or `HEAD` (a `HEAD` request must not turn into a body-carrying
    /// `GET`). On 307 and 308 method, headers and body are re-sent unchanged.
    /// Once the limit is reached the redirect response itself is returned.
    ///
    /// # Errors
    /// Fails when a URL cannot be parsed, when the exchange itself fails, or
    /// when a followed redirect has no `Location` header.
    pub async fn send(self) -> Result<Response, HttpClientError> {
        // Headers that only make sense together with a request body.
        const BODY_HEADERS: [&str; 4] = [
            "content-type",
            "content-length",
            "content-encoding",
            "transfer-encoding",
        ];

        let timeout = self.timeout_ms.unwrap_or(self.client.timeout_ms);
        let max_redirects = self.client.max_redirects;
        let mut method = self.method;
        let mut url = self.url;
        let mut headers = self.headers;
        let mut body = self.body;
        let mut redirects = 0u8;
        loop {
            let parsed = ParsedUrl::parse(&url)?;
            let resp = async_send_once(&method, &parsed, &headers, &body, timeout).await?;
            if !(resp.is_redirect() && redirects < max_redirects) {
                return Ok(resp);
            }
            let location = resp
                .header("location")
                .ok_or_else(|| HttpClientError("redirect with no Location header".into()))?
                .to_string();
            url = resolve_url(&url, &location);
            redirects += 1;
            if matches!(resp.status(), 301 | 302 | 303) {
                let keeps_method =
                    method.eq_ignore_ascii_case("GET") || method.eq_ignore_ascii_case("HEAD");
                if !keeps_method {
                    method = "GET".to_string();
                }
                if body.take().is_some() {
                    headers.retain(|(name, _)| {
                        !BODY_HEADERS.iter().any(|h| name.eq_ignore_ascii_case(h))
                    });
                }
            }
        }
    }
}
}