use std::io::{self, BufRead, BufReader, Read, Write};
use std::net::TcpStream;
use std::sync::Arc;
use std::time::Duration;
use crate::error::{Error, Result};
use crate::net::{connector_from_proxy_url, Connector, DirectConnector, NetStream};
use crate::url::Url;
/// `User-Agent` value of the form `rsurl/<crate version>`, fixed at compile time.
const DEFAULT_USER_AGENT: &str = concat!("rsurl", "/", env!("CARGO_PKG_VERSION"));
/// Upper bound on response header bytes: 64 KiB.
pub(crate) const MAX_HEADER_BYTES: usize = 64 << 10;
/// Upper bound on response body bytes: 256 MiB.
pub(crate) const MAX_BODY_BYTES: usize = 256 << 20;
/// Which HTTP version(s) a request may use.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum HttpVersionPref {
    /// Let the client pick; for direct `https` downloads this tries HTTP/2 and
    /// retries over HTTP/1.1 when h2 is not negotiated.
    #[default]
    Auto,
    /// Use HTTP/1.1 only.
    Http11Only,
    /// Require HTTP/2 with no HTTP/1.1 retry.
    Http2Only,
    /// Try HTTP/3 and fall back to `Auto` for errors deemed recoverable.
    Http3,
    /// Require HTTP/3 with no fallback.
    Http3Only,
}
/// One HTTP request together with every option that influences how it is sent.
///
/// Created with [`Request::new`] / [`Request::get`], configured through the chained
/// setters, then consumed by one of the `send*` methods.
#[derive(Clone, Debug)]
pub struct Request {
    /// Method placed on the request line; a 301-303 redirect may rewrite it to `GET`.
    pub(crate) method: String,
    /// Target URL; replaced on every followed redirect.
    pub(crate) url: Url,
    /// Extra request headers as `(name, value)` pairs, in insertion order.
    pub(crate) headers: Vec<(String, String)>,
    /// Request body bytes.
    pub(crate) body: Vec<u8>,
    /// Connect timeout (`new` sets 30 s).
    pub(crate) connect_timeout: Option<Duration>,
    /// Read timeout (`new` sets 60 s; `read_timeout(None)` clears it).
    pub(crate) read_timeout: Option<Duration>,
    /// HTTP version selection policy.
    pub(crate) http_version_pref: HttpVersionPref,
    /// Whether 301/302/303/307/308 responses carrying `Location` are followed.
    pub(crate) follow_redirects: bool,
    /// Redirect hop limit (`new` sets 50).
    pub(crate) max_redirs: u32,
    /// `(user, password)`; in digest mode these become the digest credentials.
    pub(crate) basic_auth: Option<(String, String)>,
    /// Whether server certificates are verified.
    pub(crate) verify_tls: bool,
    /// Path to a CA bundle file.
    pub(crate) ca_bundle: Option<String>,
    /// Overall time budget for the transfer.
    pub(crate) max_time: Option<Duration>,
    /// Plain-HTTP proxy to route through.
    pub(crate) proxy: Option<ProxyConfig>,
    /// Proxy bypass entries (interpreted elsewhere).
    pub(crate) no_proxy: Vec<String>,
    /// Passed to `Url::set_idn` before sending.
    pub(crate) idn: bool,
    /// Transport used to open connections; replaced for non-HTTP proxy specs.
    pub(crate) connector: Arc<dyn Connector>,
    /// Restrict resolution to IPv4 or IPv6.
    pub(crate) ip_family: Option<IpFamily>,
    /// `(host, port, ip)` resolution overrides.
    pub(crate) resolve: Vec<(String, u16, std::net::IpAddr)>,
    /// Send a `Referer` header naming the previous URL on each followed redirect.
    pub(crate) auto_referer: bool,
    /// Keep `Authorization`, `Cookie` and basic auth across host changes.
    pub(crate) redirect_trusted: bool,
    /// Indexed by `status - 301` for 301/302/303: keep the method and body instead of
    /// switching to `GET`.
    pub(crate) keep_post: [bool; 3],
    /// `(from_host, from_port, to_host, to_port)` connection rewrites.
    pub(crate) connect_to: Vec<(String, u16, String, u16)>,
    /// Minimum TLS protocol version.
    pub(crate) tls_min: Option<crate::tls::ProtocolVersion>,
    /// Maximum TLS protocol version.
    pub(crate) tls_max: Option<crate::tls::ProtocolVersion>,
    /// Use HTTP Digest instead of Basic for `basic_auth` credentials.
    pub(crate) auth_digest: bool,
    /// Client certificate path.
    pub(crate) client_cert: Option<String>,
    /// Client private key path.
    pub(crate) client_key: Option<String>,
    /// Passphrase for the client private key.
    pub(crate) client_key_pass: Option<String>,
    /// Whether `client_cert` is DER rather than the default encoding.
    pub(crate) cert_is_der: bool,
    /// Whether `client_key` is DER rather than the default encoding.
    pub(crate) key_is_der: bool,
    /// Public-key pinning specification.
    pub(crate) pinned_pubkey: Option<String>,
    /// Directory of CA certificates.
    pub(crate) ca_path: Option<String>,
    /// Certificate revocation list file.
    pub(crate) crl_file: Option<String>,
    /// Cipher list for TLS 1.2 and below (format interpreted by the TLS layer).
    pub(crate) ciphers: Option<String>,
    /// TLS 1.3 cipher-suite list.
    pub(crate) tls13_ciphers: Option<String>,
    /// Cancellation token checked before each hop.
    pub(crate) cancel: Option<crate::CancelToken>,
    /// Strict header handling flag (consumed outside this view).
    pub(crate) strict_headers: bool,
    /// Keep the method's case as given (consumed outside this view).
    pub(crate) keep_method_case: bool,
    /// When false, the cookie jar participates only in the first hop.
    pub(crate) jar_through_redirects: bool,
    /// Custom certificate verification hook.
    pub(crate) tls_verify_callback: Option<crate::tls::VerifyCallback>,
    /// Name resolver.
    pub(crate) resolver: Arc<dyn crate::net::Resolver>,
    /// Partition key (consumed outside this view).
    pub(crate) partition_key: Option<String>,
    /// Consulted at send time when neither `proxy` nor a custom connector is set.
    pub(crate) proxy_resolver: Option<Arc<dyn crate::net::ProxyResolver>>,
    /// Scheduling priority.
    pub(crate) priority: Priority,
    /// Decode `Content-Encoding` on responses.
    pub(crate) decompress: bool,
}
/// Address family a request is restricted to (see `Request::ipv4` / `Request::ipv6`).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum IpFamily {
    /// IPv4 only.
    V4,
    /// IPv6 only.
    V6,
}
/// Scheduling priority attached to a request; defaults to `Normal`.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum Priority {
    High,
    #[default]
    Normal,
    Low,
}
impl Priority {
    /// Numeric rank where a smaller value means more urgent (`High` = 0 … `Low` = 2).
    fn rank(self) -> u8 {
        match self {
            Self::High => 0,
            Self::Normal => 1,
            Self::Low => 2,
        }
    }
}
/// A plain-HTTP proxy endpoint with optional credentials.
#[derive(Clone, Debug)]
pub struct ProxyConfig {
    /// Proxy host name or address.
    pub host: String,
    /// Proxy port.
    pub port: u16,
    /// `(user, password)` taken from the spec's userinfo, if present.
    pub auth: Option<(String, String)>,
}
impl ProxyConfig {
    /// Parses a proxy spec such as `host:port`, `http://host:port` or
    /// `http://user:pass@host:port`. A spec without `://` is treated as `http://`.
    ///
    /// # Errors
    /// Propagates URL parse errors, and returns `Error::UnsupportedScheme` for any
    /// scheme other than `http`.
    pub fn parse(s: &str) -> Result<Self> {
        let with_scheme = if s.contains("://") {
            s.to_owned()
        } else {
            format!("http://{s}")
        };
        let url = Url::parse(&with_scheme)?;
        if url.scheme != "http" {
            return Err(Error::UnsupportedScheme(format!(
                "proxy scheme {:?} not supported (only http:// at this milestone)",
                url.scheme
            )));
        }
        // Userinfo without ':' is a bare user name with an empty password.
        let auth = url.userinfo.as_deref().map(|info| {
            let (user, pass) = info.split_once(':').unwrap_or((info, ""));
            (user.to_owned(), pass.to_owned())
        });
        Ok(Self {
            host: url.host.clone(),
            port: url.port,
            auth,
        })
    }
}
impl Request {
pub fn new(method: &str, url: &str) -> Result<Self> {
Ok(Request {
method: method.to_string(),
url: Url::parse(url)?,
headers: Vec::new(),
body: Vec::new(),
connect_timeout: Some(Duration::from_secs(30)),
read_timeout: Some(Duration::from_secs(60)),
http_version_pref: HttpVersionPref::Auto,
follow_redirects: false,
max_redirs: 50,
basic_auth: None,
verify_tls: true,
ca_bundle: None,
max_time: None,
proxy: None,
no_proxy: Vec::new(),
idn: true,
connector: Arc::new(DirectConnector),
ip_family: None,
resolve: Vec::new(),
auto_referer: false,
redirect_trusted: false,
keep_post: [false; 3],
connect_to: Vec::new(),
tls_min: None,
tls_max: None,
auth_digest: false,
client_cert: None,
client_key: None,
client_key_pass: None,
cert_is_der: false,
key_is_der: false,
pinned_pubkey: None,
ca_path: None,
crl_file: None,
ciphers: None,
tls13_ciphers: None,
cancel: None,
strict_headers: false,
keep_method_case: false,
jar_through_redirects: true,
tls_verify_callback: None,
resolver: Arc::new(crate::net::StdResolver),
partition_key: None,
proxy_resolver: None,
priority: Priority::Normal,
decompress: true,
})
}
pub fn priority(mut self, priority: Priority) -> Self {
self.priority = priority;
self
}
pub fn proxy_resolver(mut self, resolver: Arc<dyn crate::net::ProxyResolver>) -> Self {
self.proxy_resolver = Some(resolver);
self
}
/// Asks the configured proxy resolver (if any) for a proxy for `self.url`.
///
/// Does nothing when a proxy or non-direct connector is already set, or when no
/// resolver is installed. A returned spec with scheme `http` (or no scheme) becomes
/// `self.proxy`; any other scheme replaces `self.connector`.
///
/// # Errors
/// Propagates failures from parsing the returned proxy spec.
fn apply_proxy_resolver(&mut self) -> Result<()> {
    let explicitly_routed = self.proxy.is_some() || !self.connector.is_direct();
    if explicitly_routed {
        return Ok(());
    }
    let Some(resolver) = self.proxy_resolver.clone() else {
        return Ok(());
    };
    let crate::net::ProxyChoice::Proxy(spec) = resolver.resolve(&self.url) else {
        return Ok(());
    };
    let plain_http = match spec.split_once("://") {
        None => true,
        Some((scheme, _)) => scheme.eq_ignore_ascii_case("http"),
    };
    if !plain_http {
        self.connector = connector_from_proxy_url(&spec)?;
        return Ok(());
    }
    self.proxy = Some(ProxyConfig::parse(&spec)?);
    Ok(())
}
pub fn resolver(mut self, resolver: Arc<dyn crate::net::Resolver>) -> Self {
self.resolver = resolver;
self
}
pub fn partition(mut self, key: &str) -> Self {
self.partition_key = Some(key.to_string());
self
}
pub fn tls_verify_callback(mut self, cb: crate::tls::VerifyCallback) -> Self {
self.tls_verify_callback = Some(cb);
self
}
pub fn cookies_through_redirects(mut self, on: bool) -> Self {
self.jar_through_redirects = on;
self
}
pub fn strict_headers(mut self, on: bool) -> Self {
self.strict_headers = on;
self
}
pub fn keep_method_case(mut self, on: bool) -> Self {
self.keep_method_case = on;
self
}
pub fn cancel_token(mut self, token: crate::CancelToken) -> Self {
self.cancel = Some(token);
self
}
fn cancel_check(&self) -> Result<()> {
match &self.cancel {
Some(t) if t.is_cancelled() => Err(Error::Cancelled),
_ => Ok(()),
}
}
pub fn digest_auth(mut self, on: bool) -> Self {
self.auth_digest = on;
self
}
/// Signs the request with AWS Signature Version 4 and stores the resulting headers,
/// replacing any existing headers of the same names.
///
/// `spec` is colon-separated (`provider1:provider2:region:service`). An empty or missing
/// region falls back to the host's second DNS label, then `us-east-1`. An empty or
/// missing service falls back to the host's first label, then `s3`. The signature covers
/// the current method, host, path, query and body, so call this after setting those.
pub fn aws_sigv4(mut self, spec: &str, access: &str, secret: &str) -> Self {
    let parts: Vec<&str> = spec.split(':').collect();
    let labels: Vec<&str> = self.url.host.split('.').collect();
    let region = parts
        .get(2)
        .filter(|s| !s.is_empty())
        .copied()
        .or_else(|| labels.get(1).copied())
        .unwrap_or("us-east-1");
    let service = parts
        .get(3)
        .filter(|s| !s.is_empty())
        .copied()
        .or_else(|| labels.first().copied())
        .unwrap_or("s3");
    let (path, query) = match self.url.path.split_once('?') {
        Some((p, q)) => (p.to_string(), q.to_string()),
        None => (self.url.path.clone(), String::new()),
    };
    let amz = crate::sigv4::amz_date_now();
    let cfg = crate::sigv4::SigV4 {
        access_key: access,
        secret_key: secret,
        region,
        service,
    };
    // `method`, `url` and `body` are fields disjoint from `headers`, so they can be
    // borrowed directly while the headers are rewritten. This avoids copying a
    // potentially large request body just to sign it.
    let signed = crate::sigv4::sign(
        &cfg,
        &self.method,
        &self.url.host,
        &path,
        &query,
        &self.body,
        &amz,
    );
    for (k, v) in signed {
        self.headers.retain(|(hk, _)| !hk.eq_ignore_ascii_case(&k));
        self.headers.push((k, v));
    }
    self
}
pub fn tls_min_version(mut self, v: crate::tls::ProtocolVersion) -> Self {
self.tls_min = Some(v);
self
}
pub fn tls_max_version(mut self, v: crate::tls::ProtocolVersion) -> Self {
self.tls_max = Some(v);
self
}
pub fn connect_to(
mut self,
from_host: &str,
from_port: u16,
to_host: &str,
to_port: u16,
) -> Self {
self.connect_to.push((
from_host.to_string(),
from_port,
to_host.to_string(),
to_port,
));
self
}
pub fn redirect_trusted(mut self, on: bool) -> Self {
self.redirect_trusted = on;
self
}
pub fn keep_post_on(mut self, status: u16) -> Self {
if (301..=303).contains(&status) {
self.keep_post[(status - 301) as usize] = true;
}
self
}
pub fn auto_referer(mut self, on: bool) -> Self {
self.auto_referer = on;
self
}
pub fn ipv4(mut self) -> Self {
self.ip_family = Some(IpFamily::V4);
self
}
pub fn ipv6(mut self) -> Self {
self.ip_family = Some(IpFamily::V6);
self
}
pub fn resolve_addr(mut self, host: &str, port: u16, ip: std::net::IpAddr) -> Self {
self.resolve.push((host.to_string(), port, ip));
self
}
pub fn get(url: &str) -> Result<Self> {
Self::new("GET", url)
}
pub fn method(mut self, method: &str) -> Self {
self.method = method.to_string();
self
}
/// Appends a request header; existing headers of the same name are kept.
pub fn header(mut self, name: &str, value: &str) -> Self {
    let pair = (name.to_owned(), value.to_owned());
    self.headers.push(pair);
    self
}
/// Replaces the request body.
pub fn body<B: Into<Vec<u8>>>(self, body: B) -> Self {
    Self {
        body: body.into(),
        ..self
    }
}
/// Borrows the target URL.
pub fn url(&self) -> &Url {
    &self.url
}
/// Sets the HTTP version policy.
pub fn http_version(self, pref: HttpVersionPref) -> Self {
    Self {
        http_version_pref: pref,
        ..self
    }
}
/// Same as `http_version(HttpVersionPref::Http2Only)`.
pub fn http2_only(self) -> Self {
    self.http_version(HttpVersionPref::Http2Only)
}
/// Same as `http_version(HttpVersionPref::Http11Only)`.
pub fn http11_only(self) -> Self {
    self.http_version(HttpVersionPref::Http11Only)
}
/// Same as `http_version(HttpVersionPref::Http3)`.
pub fn http3(self) -> Self {
    self.http_version(HttpVersionPref::Http3)
}
/// Same as `http_version(HttpVersionPref::Http3Only)`.
pub fn http3_only(self) -> Self {
    self.http_version(HttpVersionPref::Http3Only)
}
/// Enables or disables following redirects.
pub fn follow_redirects(self, on: bool) -> Self {
    Self {
        follow_redirects: on,
        ..self
    }
}
/// Sets the redirect hop limit.
pub fn max_redirs(self, n: u32) -> Self {
    Self {
        max_redirs: n,
        ..self
    }
}
/// Sets `(user, password)` credentials.
pub fn basic_auth(self, user: &str, pass: &str) -> Self {
    Self {
        basic_auth: Some((user.to_owned(), pass.to_owned())),
        ..self
    }
}
/// Enables or disables TLS certificate verification.
pub fn verify_tls(self, on: bool) -> Self {
    Self {
        verify_tls: on,
        ..self
    }
}
/// Enables or disables IDN host conversion.
pub fn idn(self, on: bool) -> Self {
    Self { idn: on, ..self }
}
/// Enables or disables `Content-Encoding` decoding.
pub fn decompress(self, on: bool) -> Self {
    Self {
        decompress: on,
        ..self
    }
}
/// Sets the CA bundle file path.
pub fn ca_bundle(self, path: &str) -> Self {
    Self {
        ca_bundle: Some(path.to_owned()),
        ..self
    }
}
/// Sets the CA certificate directory.
pub fn ca_path(self, dir: &str) -> Self {
    Self {
        ca_path: Some(dir.to_owned()),
        ..self
    }
}
/// Sets the CRL file path.
pub fn crl_file(self, path: &str) -> Self {
    Self {
        crl_file: Some(path.to_owned()),
        ..self
    }
}
/// Sets the cipher list.
pub fn ciphers(self, list: &str) -> Self {
    Self {
        ciphers: Some(list.to_owned()),
        ..self
    }
}
/// Sets the TLS 1.3 cipher-suite list.
pub fn tls13_ciphers(self, list: &str) -> Self {
    Self {
        tls13_ciphers: Some(list.to_owned()),
        ..self
    }
}
/// Sets the client certificate path.
pub fn client_cert(self, path: &str) -> Self {
    Self {
        client_cert: Some(path.to_owned()),
        ..self
    }
}
/// Sets the client private key path.
pub fn client_key(self, path: &str) -> Self {
    Self {
        client_key: Some(path.to_owned()),
        ..self
    }
}
/// Sets the client private key passphrase.
pub fn client_key_pass(self, pass: &str) -> Self {
    Self {
        client_key_pass: Some(pass.to_owned()),
        ..self
    }
}
/// Marks the client certificate as DER-encoded.
pub fn cert_type_der(self, der: bool) -> Self {
    Self {
        cert_is_der: der,
        ..self
    }
}
/// Marks the client key as DER-encoded.
pub fn key_type_der(self, der: bool) -> Self {
    Self {
        key_is_der: der,
        ..self
    }
}
/// Sets the public-key pinning specification.
pub fn pinned_pubkey(self, spec: &str) -> Self {
    Self {
        pinned_pubkey: Some(spec.to_owned()),
        ..self
    }
}
/// Sets the overall time budget.
pub fn max_time(self, d: Duration) -> Self {
    Self {
        max_time: Some(d),
        ..self
    }
}
/// Sets the connect timeout.
pub fn connect_timeout(self, d: Duration) -> Self {
    Self {
        connect_timeout: Some(d),
        ..self
    }
}
/// Sets (or with `None` clears) the read timeout.
pub fn read_timeout(self, d: Option<Duration>) -> Self {
    Self {
        read_timeout: d,
        ..self
    }
}
pub fn proxy(mut self, spec: &str) -> Result<Self> {
let is_http_proxy = match spec.split_once("://") {
Some((scheme, _)) => scheme.eq_ignore_ascii_case("http"),
None => true, };
if is_http_proxy {
self.proxy = Some(ProxyConfig::parse(spec)?);
} else {
self.connector = connector_from_proxy_url(spec)?;
}
Ok(self)
}
pub fn connector(mut self, connector: Arc<dyn Connector>) -> Self {
self.connector = connector;
self
}
pub fn proxy_user(mut self, user: &str, pass: &str) -> Result<Self> {
match self.proxy.as_mut() {
Some(p) => {
p.auth = Some((user.to_string(), pass.to_string()));
Ok(self)
}
None => Err(Error::BadResponse(
"proxy_user called without a proxy set".into(),
)),
}
}
pub fn no_proxy<I, S>(mut self, entries: I) -> Self
where
I: IntoIterator<Item = S>,
S: Into<String>,
{
self.no_proxy = entries.into_iter().map(Into::into).collect();
self
}
/// Sends the request, buffering the whole response body; trace output is discarded.
pub fn send(self) -> Result<Response> {
    let mut discard = io::sink();
    self.send_to(&mut discard, None)
}
/// Like [`Request::send`], but writes trace lines to `trace`.
pub fn send_traced(self, trace: &mut dyn Write) -> Result<Response> {
    self.send_to(trace, None)
}
/// Like [`Request::send`], using and updating `jar`.
pub fn send_with_jar(self, jar: &mut crate::cookie::CookieJar) -> Result<Response> {
    let mut discard = io::sink();
    self.send_to(&mut discard, Some(jar))
}
/// Like [`Request::send_with_jar`], but writes trace lines to `trace`.
pub fn send_traced_with_jar(
    self,
    jar: &mut crate::cookie::CookieJar,
    trace: &mut dyn Write,
) -> Result<Response> {
    self.send_to(trace, Some(jar))
}
/// Streams the response: `on_head` sees the final head, `on_chunk` each body chunk.
pub fn send_streaming(
    self,
    on_head: impl FnMut(&ResponseHead),
    on_chunk: impl FnMut(&[u8]) -> Result<()>,
) -> Result<Response> {
    self.send_streaming_with(None, on_head, on_chunk)
}
/// Like [`Request::send_streaming`], optionally with a cookie jar.
///
/// An error returned by `on_chunk` aborts the transfer and takes precedence over the
/// transfer's own result.
pub fn send_streaming_with(
    self,
    jar: Option<&mut crate::cookie::CookieJar>,
    mut on_head: impl FnMut(&ResponseHead),
    mut on_chunk: impl FnMut(&[u8]) -> Result<()>,
) -> Result<Response> {
    let mut observer = |head: &ResponseHead| on_head(head);
    let mut chunks = ChunkSink::new(&mut on_chunk);
    let mut discard = io::sink();
    let outcome =
        self.send_download_observed(&mut chunks, jar, &mut discard, Some(&mut observer));
    chunks.into_result(outcome)
}
/// Writes the response body into `sink`; the returned [`Response`] has an empty body.
pub fn send_download(
    self,
    sink: &mut dyn Write,
    jar: Option<&mut crate::cookie::CookieJar>,
    trace: &mut dyn Write,
) -> Result<Response> {
    self.send_download_observed(sink, jar, trace, None)
}
/// Returns a [`BodyReader`] over the response body; trace output is discarded.
pub fn send_reader(self) -> Result<BodyReader> {
    let mut discard = io::sink();
    self.send_reader_traced(&mut discard)
}
/// Sends the request and returns a [`BodyReader`] yielding the response body, writing
/// trace lines to `trace`.
///
/// Only some requests are read straight off the HTTP/1.1 socket. They must be direct
/// (default connector, no HTTP proxy), use `http`/`https`, not be pinned to HTTP/2 or
/// HTTP/3, and not use digest auth. Everything else goes through
/// `buffered_body_reader`. Redirects are followed when `follow_redirects` is set, up to
/// `max_redirs` hops.
///
/// # Errors
/// Returns an error on cancellation (checked before every hop), connection/TLS/I/O
/// failures, or when more than `max_redirs` redirects are seen.
pub fn send_reader_traced(self, trace: &mut dyn Write) -> Result<BodyReader> {
    let mut req = self;
    req.url.set_idn(req.idn)?;
    req.apply_proxy_resolver()?;
    let streamable = req.connector.is_direct()
        && req.proxy.is_none()
        // Digest auth needs the 401 challenge/retry loop of the buffered path
        // (`send_to_inner`), which also withholds `basic_auth` in digest mode; the raw
        // socket path below does neither.
        && !req.auth_digest
        && !matches!(
            req.http_version_pref,
            HttpVersionPref::Http2Only | HttpVersionPref::Http3Only
        )
        && (req.url.scheme == "http" || req.url.scheme == "https");
    if !streamable {
        return buffered_body_reader(req, trace);
    }
    let mut hops_left = req.max_redirs;
    loop {
        // Checked on every hop, so a cancellation between redirects is honoured.
        req.cancel_check()?;
        let (head, mut bufrd, cancel) = open_h1_for_reader(&req, trace)?;
        if req.follow_redirects && is_redirect_status(head.status) {
            if let Some(loc) = head
                .headers
                .iter()
                .find(|(k, _)| k.eq_ignore_ascii_case("location"))
                .map(|(_, v)| v.clone())
            {
                // Drain the redirect body, then release this hop's connection.
                let _ = read_body(
                    &mut bufrd,
                    &head.headers,
                    &head.version,
                    head.status,
                    &req.method,
                )?;
                drop(bufrd);
                drop(cancel);
                if hops_left == 0 {
                    return Err(Error::BadResponse(format!(
                        "maximum ({}) redirects followed",
                        req.max_redirs
                    )));
                }
                hops_left -= 1;
                req = redirect_request(req, &head, &loc)?;
                continue;
            }
        }
        return build_body_reader(&req, head, bufrd, cancel);
    }
}
/// Runs `send_download_dispatch`, stamping `timing.total` on success and turning any
/// failure that coincides with a fired cancel token into `Error::Cancelled`.
fn send_download_observed(
    self,
    sink: &mut dyn Write,
    jar: Option<&mut crate::cookie::CookieJar>,
    trace: &mut dyn Write,
    on_head: Option<HeadObserver>,
) -> Result<Response> {
    let token = self.cancel.clone();
    let t0 = std::time::Instant::now();
    let outcome = self
        .send_download_dispatch(sink, jar, trace, on_head)
        .map(|mut resp| {
            resp.timing.total = Some(t0.elapsed());
            resp
        });
    map_cancelled(&token, outcome)
}
fn send_download_dispatch(
mut self,
sink: &mut dyn Write,
mut jar: Option<&mut crate::cookie::CookieJar>,
trace: &mut dyn Write,
mut on_head: Option<HeadObserver>,
) -> Result<Response> {
self.cancel_check()?;
self.apply_proxy_resolver()?;
let direct = self.connector.is_direct() && self.proxy.is_none();
let h1_streamable = direct
&& (self.url.scheme == "http"
|| matches!(self.http_version_pref, HttpVersionPref::Http11Only));
if h1_streamable {
return self.stream_h1(sink, jar, trace, on_head);
}
let h2_streamable = direct
&& self.url.scheme == "https"
&& !self.follow_redirects
&& matches!(
self.http_version_pref,
HttpVersionPref::Auto | HttpVersionPref::Http2Only
);
if h2_streamable {
let mut req = self;
req.url.set_idn(req.idn)?;
if let Some(j) = jar.as_deref_mut() {
j.purge_expired();
req.headers
.retain(|(k, _)| !k.eq_ignore_ascii_case("cookie"));
if let Some(val) = j.cookie_header(&req.url) {
req.headers.push(("Cookie".to_string(), val));
}
}
let url = req.url.clone();
let force_h2 = matches!(req.http_version_pref, HttpVersionPref::Http2Only);
let fallback = (!force_h2).then(|| {
let mut fb = req.clone();
fb.http_version_pref = HttpVersionPref::Http11Only;
fb
});
let head_reborrow: Option<HeadObserver> = match &mut on_head {
Some(f) => Some(&mut **f),
None => None,
};
match crate::http2::send_to(req, sink, head_reborrow, trace) {
Ok(resp) => {
if let Some(j) = jar {
j.ingest_response(&url, &resp.headers);
}
return Ok(resp);
}
Err(Error::H2NotNegotiated) => {
if let Some(fb) = fallback {
return fb.send_download_observed(sink, jar, trace, on_head);
}
return Err(Error::H2NotNegotiated);
}
Err(e) => return Err(e),
}
}
let h3_streamable = direct
&& self.url.scheme == "https"
&& !self.follow_redirects
&& matches!(
self.http_version_pref,
HttpVersionPref::Http3 | HttpVersionPref::Http3Only
);
if h3_streamable {
let mut req = self;
req.url.set_idn(req.idn)?;
if let Some(j) = jar.as_deref_mut() {
j.purge_expired();
req.headers
.retain(|(k, _)| !k.eq_ignore_ascii_case("cookie"));
if let Some(val) = j.cookie_header(&req.url) {
req.headers.push(("Cookie".to_string(), val));
}
}
let url = req.url.clone();
let force_h3 = matches!(req.http_version_pref, HttpVersionPref::Http3Only);
let fallback = (!force_h3).then(|| {
let mut fb = req.clone();
fb.http_version_pref = HttpVersionPref::Auto;
fb
});
let head_reborrow: Option<HeadObserver> = match &mut on_head {
Some(f) => Some(&mut **f),
None => None,
};
match crate::http3::send_to(req, sink, head_reborrow, trace) {
Ok(resp) => {
if let Some(j) = jar {
j.ingest_response(&url, &resp.headers);
}
return Ok(resp);
}
Err(e) if fallback.is_some() && h3_should_fall_back(&e) => {
let _ = writeln!(trace, "* HTTP/3 failed ({e}); falling back");
return fallback
.unwrap()
.send_download_observed(sink, jar, trace, on_head);
}
Err(e) => return Err(e),
}
}
self.send_download_buffered(sink, jar, trace, on_head)
}
fn send_download_buffered(
self,
sink: &mut dyn Write,
jar: Option<&mut crate::cookie::CookieJar>,
trace: &mut dyn Write,
on_head: Option<HeadObserver>,
) -> Result<Response> {
let resp = self.send_to(trace, jar)?;
if let Some(obs) = on_head {
obs(&ResponseHead {
status: resp.status,
reason: resp.reason.clone(),
version: resp.version.clone(),
headers: resp.headers.clone(),
});
}
sink.write_all(&resp.body)?;
Ok(Response {
body: Vec::new(),
..resp
})
}
fn stream_h1(
self,
sink: &mut dyn Write,
mut jar: Option<&mut crate::cookie::CookieJar>,
trace: &mut dyn Write,
mut on_head: Option<HeadObserver>,
) -> Result<Response> {
let mut req = self;
req.url.set_idn(req.idn)?;
let mut hops_left = req.max_redirs;
let mut first_hop = true;
loop {
req.cancel_check()?;
let use_jar = req.jar_through_redirects || first_hop;
let start = std::time::Instant::now();
let mut timing = Timing::default();
let _cancel_guard;
let mut tls_info: Option<TlsInfo> = None;
let stream: Box<dyn Rw> = if req.url.scheme == "https" {
let (tcp, guard, namelookup) = tcp_connect_cancellable(&req, trace)?;
_cancel_guard = guard;
timing.namelookup = namelookup;
timing.connect = Some(start.elapsed());
let opts = tls_opts_from(&req, &[])?;
let tls = crate::tls::connect_over_tls(tcp, &req.url.host, opts)?;
let appconnect = start.elapsed();
timing.appconnect = Some(appconnect);
timing.pretransfer = Some(appconnect);
write_tls_info(&tls, trace);
tls_info = Some(tls_info_from(&tls));
Box::new(tls)
} else {
let (tcp, guard, namelookup) = tcp_connect_cancellable(&req, trace)?;
_cancel_guard = guard;
timing.namelookup = namelookup;
let connect = start.elapsed();
timing.connect = Some(connect);
timing.pretransfer = Some(connect);
Box::new(tcp)
};
let mut bufrd = BufReader::new(stream);
let mut snapshot = req.clone();
if let Some(j) = jar.as_deref_mut().filter(|_| use_jar) {
j.purge_expired();
snapshot
.headers
.retain(|(k, _)| !k.eq_ignore_ascii_case("cookie"));
if let Some(val) = j.cookie_header(&snapshot.url) {
snapshot.headers.push(("Cookie".to_string(), val));
}
}
write_request(
bufrd.get_mut(),
&snapshot,
via_plain_http_proxy(&req),
trace,
)?;
let head = read_head(&mut bufrd, trace)?;
timing.starttransfer = Some(start.elapsed());
if let Some(j) = jar.as_deref_mut().filter(|_| use_jar) {
j.ingest_response(&req.url, &head.headers);
}
if req.follow_redirects && is_redirect_status(head.status) {
if let Some((_, loc)) = head
.headers
.iter()
.find(|(k, _)| k.eq_ignore_ascii_case("location"))
{
let location = loc.clone();
let _ = read_body(
&mut bufrd,
&head.headers,
&head.version,
head.status,
&req.method,
)?;
if hops_left == 0 {
return Err(Error::BadResponse(format!(
"maximum ({}) redirects followed",
req.max_redirs
)));
}
let mut next_url = crate::url::resolve(&req.url, &location)?;
if next_url.scheme != "http" && next_url.scheme != "https" {
return Err(Error::UnsupportedScheme(next_url.scheme.clone()));
}
next_url.set_idn(req.idn)?;
let _ = writeln!(
trace,
"* Following redirect to {}",
url_to_string(&next_url)
);
let host_changed = next_url.host != req.url.host
|| next_url.port != req.url.port
|| next_url.scheme != req.url.scheme;
let prev_method = req.method.clone();
let prev_url = url_to_string(&req.url);
let prev_body = std::mem::take(&mut req.body);
let status = head.status;
let mut next = req;
next.url = next_url;
if next.auto_referer {
next.headers
.retain(|(k, _)| !k.eq_ignore_ascii_case("referer"));
next.headers.push(("Referer".to_string(), prev_url));
}
if host_changed && !next.redirect_trusted {
next.headers.retain(|(k, _)| {
!k.eq_ignore_ascii_case("authorization")
&& !k.eq_ignore_ascii_case("cookie")
});
next.basic_auth = None;
}
let keep_post = if (301..=303).contains(&status) {
next.keep_post[(status - 301) as usize]
} else {
false
};
if (301..=303).contains(&status)
&& !keep_post
&& !prev_method.eq_ignore_ascii_case("GET")
&& !prev_method.eq_ignore_ascii_case("HEAD")
{
next.method = "GET".to_string();
next.headers.retain(|(k, _)| {
!k.eq_ignore_ascii_case("content-type")
&& !k.eq_ignore_ascii_case("content-length")
&& !k.eq_ignore_ascii_case("transfer-encoding")
});
} else {
next.body = prev_body;
}
hops_left -= 1;
req = next;
first_hop = false;
continue;
}
}
if let Some(obs) = on_head.take() {
obs(&ResponseHead {
status: head.status,
reason: head.reason.clone(),
version: head.version.clone(),
headers: head.headers.clone(),
});
}
let content_encoding = head
.headers
.iter()
.find(|(k, _)| k.eq_ignore_ascii_case("content-encoding"))
.map(|(_, v)| v.clone())
.filter(|_| req.decompress);
if let Some(ce) = content_encoding {
let chunked = head.headers.iter().any(|(k, v)| {
k.eq_ignore_ascii_case("transfer-encoding") && v.eq_ignore_ascii_case("chunked")
});
let has_body = !(req.method.eq_ignore_ascii_case("HEAD")
|| (100..200).contains(&head.status)
|| head.status == 204
|| head.status == 304);
if has_body && !chunked {
if let (Some(codec), Some(len)) = (
crate::compress::single_streamable_layer(&ce),
parse_content_length(&head.headers)?,
) {
if len > MAX_BODY_BYTES as u64 {
return Err(Error::BadResponse(format!("body too large: {len}")));
}
let n = crate::compress::stream_decode(
bufrd.by_ref().take(len),
codec,
sink,
MAX_BODY_BYTES as u64,
)?;
let _ = writeln!(trace, "* Stream-decoded {n} body bytes ({codec:?})");
return Ok(Response {
status: head.status,
reason: head.reason,
version: head.version,
headers: crate::compress::strip_after_decode(head.headers),
body: Vec::new(),
timing,
final_url: url_to_string(&req.url),
tls: tls_info.clone(),
});
}
}
let body = read_body(
&mut bufrd,
&head.headers,
&head.version,
head.status,
&req.method,
)?;
let (headers, body) = maybe_decode_body(head.headers, body, req.decompress, trace)?;
sink.write_all(&body)?;
return Ok(Response {
status: head.status,
reason: head.reason,
version: head.version,
headers,
body: Vec::new(),
timing,
final_url: url_to_string(&req.url),
tls: tls_info.clone(),
});
}
let n = stream_body(&mut bufrd, sink, &head.headers, head.status, &req.method)?;
let _ = writeln!(trace, "* Streamed {n} body bytes");
return Ok(Response {
status: head.status,
reason: head.reason,
version: head.version,
headers: head.headers,
body: Vec::new(),
timing,
final_url: url_to_string(&req.url),
tls: tls_info.clone(),
});
}
}
fn send_once(self, trace: &mut dyn Write) -> Result<Response> {
if !self.verify_tls && self.url.scheme == "https" {
let _ = writeln!(trace, "* WARNING: certificate verification disabled (-k)");
}
match self.url.scheme.as_str() {
"http" => send_plain(self, trace),
"https" => send_https(self, trace),
other => Err(Error::UnsupportedScheme(other.to_string())),
}
}
fn send_to(
self,
trace: &mut dyn Write,
jar: Option<&mut crate::cookie::CookieJar>,
) -> Result<Response> {
let cancel = self.cancel.clone();
let started = std::time::Instant::now();
let mut r = self.send_to_inner(trace, jar);
if let Ok(resp) = &mut r {
resp.timing.total = Some(started.elapsed());
}
map_cancelled(&cancel, r)
}
/// Buffered send loop: sends one exchange per iteration and handles cookies, a single
/// Digest retry per hop, redirects and the `max_time` budget.
///
/// In digest mode (`auth_digest`) the `basic_auth` credentials are taken out of the
/// request. They are used only to answer a `WWW-Authenticate: Digest` challenge, and
/// like `basic_auth` they are dropped on a redirect to a different origin unless
/// `redirect_trusted` is set.
///
/// # Errors
/// Returns an error on cancellation, when `max_time` has elapsed before a hop, on
/// transport errors, or when more than `max_redirs` redirects are seen.
fn send_to_inner(
    self,
    trace: &mut dyn Write,
    mut jar: Option<&mut crate::cookie::CookieJar>,
) -> Result<Response> {
    /// True when `challenge` starts with the `Digest` auth-scheme token
    /// (case-insensitive), followed by whitespace or the end of the value.
    fn is_digest_challenge(challenge: &str) -> bool {
        let bytes = challenge.trim_start().as_bytes();
        match bytes.get(..6) {
            Some(scheme) if scheme.eq_ignore_ascii_case(b"digest") => {
                bytes.get(6).map_or(true, u8::is_ascii_whitespace)
            }
            _ => false,
        }
    }

    let mut req = self;
    req.url.set_idn(req.idn)?;
    req.apply_proxy_resolver()?;
    let mut digest_creds = if req.auth_digest {
        req.basic_auth.take()
    } else {
        None
    };
    let mut digest_tried = false;
    let deadline = req.max_time.map(|d| std::time::Instant::now() + d);
    let mut hops_left = req.max_redirs;
    let mut first_hop = true;
    loop {
        req.cancel_check()?;
        let use_jar = req.jar_through_redirects || first_hop;
        if let Some(end) = deadline {
            if std::time::Instant::now() >= end {
                return Err(Error::BadResponse("operation timed out".into()));
            }
        }
        // Jar cookies go on a per-hop copy so `req` keeps only the caller's headers.
        let mut snapshot = req.clone();
        if let Some(j) = jar.as_deref_mut().filter(|_| use_jar) {
            j.purge_expired();
            snapshot
                .headers
                .retain(|(k, _)| !k.eq_ignore_ascii_case("cookie"));
            if let Some(val) = j.cookie_header(&snapshot.url) {
                snapshot.headers.push(("Cookie".to_string(), val));
            }
        }
        let mut resp = snapshot.send_once(trace)?;
        if let Some(j) = jar.as_deref_mut().filter(|_| use_jar) {
            j.ingest_response(&req.url, &resp.headers);
        }
        if resp.status == 401 && !digest_tried {
            if let Some((u, p)) = digest_creds.as_ref() {
                // Servers may offer several schemes in separate `WWW-Authenticate`
                // headers (e.g. Basic first, Digest second); answer the first Digest one.
                let challenge = resp
                    .headers
                    .iter()
                    .filter(|(k, _)| k.eq_ignore_ascii_case("www-authenticate"))
                    .map(|(_, v)| v.as_str())
                    .find(|v| is_digest_challenge(v));
                if let Some(chal) = challenge {
                    if let Some(h) =
                        crate::digest::authorization(u, p, &req.method, &req.url.path, chal)
                    {
                        req.headers
                            .retain(|(k, _)| !k.eq_ignore_ascii_case("authorization"));
                        req.headers.push(("Authorization".to_string(), h));
                        digest_tried = true;
                        continue;
                    }
                }
            }
        }
        if !req.follow_redirects || !is_redirect_status(resp.status) {
            resp.final_url = url_to_string(&req.url);
            return Ok(resp);
        }
        if hops_left == 0 {
            return Err(Error::BadResponse(format!(
                "maximum ({}) redirects followed",
                req.max_redirs
            )));
        }
        let location = match resp.header("location") {
            Some(l) => l.to_string(),
            None => {
                resp.final_url = url_to_string(&req.url);
                return Ok(resp);
            }
        };
        let mut next_url = crate::url::resolve(&req.url, &location)?;
        next_url.set_idn(req.idn)?;
        let _ = writeln!(
            trace,
            "* Following redirect to {}",
            url_to_string(&next_url)
        );
        let host_changed = next_url.host != req.url.host
            || next_url.port != req.url.port
            || next_url.scheme != req.url.scheme;
        let prev_method = req.method.clone();
        let prev_url = url_to_string(&req.url);
        let prev_body = std::mem::take(&mut req.body);
        let mut next = req;
        next.url = next_url;
        if next.auto_referer {
            next.headers
                .retain(|(k, _)| !k.eq_ignore_ascii_case("referer"));
            next.headers.push(("Referer".to_string(), prev_url));
        }
        if host_changed && !next.redirect_trusted {
            next.headers.retain(|(k, _)| {
                !k.eq_ignore_ascii_case("authorization") && !k.eq_ignore_ascii_case("cookie")
            });
            next.basic_auth = None;
            // The held-back digest credentials must not be offered to another origin either.
            digest_creds = None;
        }
        let keep_post = if (301..=303).contains(&resp.status) {
            next.keep_post[(resp.status - 301) as usize]
        } else {
            false
        };
        // 301/302/303 turn a non-GET/HEAD request into a bodiless GET unless
        // `keep_post_on` asked otherwise; 307/308 always keep method and body.
        if (301..=303).contains(&resp.status)
            && !keep_post
            && !prev_method.eq_ignore_ascii_case("GET")
            && !prev_method.eq_ignore_ascii_case("HEAD")
        {
            next.method = "GET".to_string();
            next.headers.retain(|(k, _)| {
                !k.eq_ignore_ascii_case("content-type")
                    && !k.eq_ignore_ascii_case("content-length")
                    && !k.eq_ignore_ascii_case("transfer-encoding")
            });
        } else {
            next.body = prev_body;
        }
        // Each hop gets its own single digest attempt.
        digest_tried = false;
        hops_left -= 1;
        req = next;
        first_hop = false;
    }
}
}
/// True for the redirect statuses this client follows: 301, 302, 303, 307 and 308.
fn is_redirect_status(status: u16) -> bool {
    (301..=303).contains(&status) || status == 307 || status == 308
}
fn map_cancelled(cancel: &Option<crate::CancelToken>, r: Result<Response>) -> Result<Response> {
match (cancel, &r) {
(Some(t), Err(_)) if t.is_cancelled() => Err(Error::Cancelled),
_ => r,
}
}
/// Formats `u` as `scheme://host[:port]path`, omitting the port when it is the scheme
/// default (80 for `http`, 443 for `https`). Userinfo is never included.
///
/// A host containing `:` can only be an IPv6 literal. If it is not already bracketed it
/// is wrapped in `[...]`, as URL authority syntax requires; otherwise
/// `http://::1:8080/` would be ambiguous.
fn url_to_string(u: &Url) -> String {
    let host: std::borrow::Cow<'_, str> = if u.host.contains(':') && !u.host.starts_with('[') {
        format!("[{}]", u.host).into()
    } else {
        u.host.as_str().into()
    };
    let default = matches!((u.scheme.as_str(), u.port), ("http", 80) | ("https", 443));
    if default {
        format!("{}://{}{}", u.scheme, host, u.path)
    } else {
        format!("{}://{}:{}{}", u.scheme, host, u.port, u.path)
    }
}
/// Details of the TLS session a response arrived on.
#[derive(Clone, Debug)]
pub struct TlsInfo {
    /// Negotiated protocol version, if known.
    pub version: Option<crate::tls::ProtocolVersion>,
    /// Negotiated cipher suite identifier, if known.
    pub cipher_suite: Option<u16>,
    /// Negotiated ALPN protocol bytes, if any.
    pub alpn: Option<Vec<u8>>,
    /// Certificates presented by the peer (encoding determined by the TLS layer).
    pub peer_certificates: Vec<Vec<u8>>,
}
/// A completed response.
#[derive(Clone, Debug)]
pub struct Response {
    /// Status code.
    pub status: u16,
    /// Reason phrase.
    pub reason: String,
    /// Protocol version string.
    pub version: String,
    /// Headers as `(name, value)` pairs, in the order they were received.
    pub headers: Vec<(String, String)>,
    /// Body bytes; empty when the body was streamed to a sink instead.
    pub body: Vec<u8>,
    /// TLS session details, when the final hop used TLS and they were captured.
    pub tls: Option<TlsInfo>,
    /// Phase timings.
    pub timing: Timing,
    /// URL of the final hop, formatted as `scheme://host[:port]path`.
    pub final_url: String,
}
/// Per-phase timings.
///
/// In the HTTP/1.1 streaming path, the phase values are offsets from the start of the
/// final hop, while `total` spans the whole call.
#[derive(Clone, Debug, Default)]
pub struct Timing {
    /// Name resolution time, as reported by the connector.
    pub namelookup: Option<Duration>,
    /// TCP connection established.
    pub connect: Option<Duration>,
    /// TLS handshake finished (`https` only).
    pub appconnect: Option<Duration>,
    /// Ready to send the request.
    pub pretransfer: Option<Duration>,
    /// Response head received.
    pub starttransfer: Option<Duration>,
    /// Whole call, stamped by the top-level send wrappers.
    pub total: Option<Duration>,
}
/// Status line and headers of a response, without its body.
#[derive(Clone, Debug)]
pub struct ResponseHead {
    pub status: u16,
    pub reason: String,
    pub version: String,
    pub headers: Vec<(String, String)>,
}
impl ResponseHead {
    /// Value of the first header whose name equals `name`, ignoring ASCII case.
    pub fn header(&self, name: &str) -> Option<&str> {
        self.headers
            .iter()
            .find_map(|(k, v)| k.eq_ignore_ascii_case(name).then_some(v.as_str()))
    }
}
/// Callback invoked with a response head before its body is delivered.
pub(crate) type HeadObserver<'obs> = &'obs mut dyn FnMut(&ResponseHead);
/// `Write` adapter that forwards every buffer to a fallible chunk callback.
///
/// The first callback error is stored, and the write fails with a generic I/O error.
/// `into_result` later reports the stored error in preference to the transfer's result.
struct ChunkSink<'a> {
    on_chunk: &'a mut dyn FnMut(&[u8]) -> Result<()>,
    err: Option<Error>,
}
impl<'a> ChunkSink<'a> {
    fn new(on_chunk: &'a mut dyn FnMut(&[u8]) -> Result<()>) -> Self {
        Self {
            err: None,
            on_chunk,
        }
    }
    /// Returns the stashed callback error if there is one, otherwise `r`.
    fn into_result(self, r: Result<Response>) -> Result<Response> {
        if let Some(e) = self.err {
            return Err(e);
        }
        r
    }
}
impl Write for ChunkSink<'_> {
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        if let Err(e) = (self.on_chunk)(buf) {
            self.err = Some(e);
            return Err(io::Error::other("chunk callback aborted the transfer"));
        }
        Ok(buf.len())
    }
    /// Nothing is buffered here, so flushing is a no-op.
    fn flush(&mut self) -> io::Result<()> {
        Ok(())
    }
}
impl Response {
    /// Value of the first header whose name equals `name`, ignoring ASCII case.
    pub fn header(&self, name: &str) -> Option<&str> {
        self.headers
            .iter()
            .find(|(k, _)| k.eq_ignore_ascii_case(name))
            .map(|(_, v)| v.as_str())
    }
    /// Lowercased, unquoted `charset` parameter of `Content-Type`, if present.
    ///
    /// Parameters without `=` are skipped instead of ending the search, so
    /// `text/html; foo; charset=latin1` still yields `latin1`.
    fn charset(&self) -> Option<String> {
        let ct = self.header("content-type")?;
        ct.split(';').skip(1).find_map(|param| {
            let (k, v) = param.trim().split_once('=')?;
            k.trim()
                .eq_ignore_ascii_case("charset")
                .then(|| v.trim().trim_matches('"').to_ascii_lowercase())
        })
    }
    /// Decodes the body as text according to the `Content-Type` charset.
    ///
    /// No charset, UTF-8 and ASCII are decoded as UTF-8, with invalid sequences replaced
    /// by U+FFFD. ISO-8859-1 and its common aliases map each byte to the code point of
    /// the same value.
    ///
    /// # Errors
    /// `Error::Decode` for any other charset.
    pub fn text(&self) -> Result<String> {
        match self.charset().as_deref() {
            None | Some("utf-8" | "utf8" | "us-ascii" | "ascii") => {
                Ok(String::from_utf8_lossy(&self.body).into_owned())
            }
            Some("iso-8859-1" | "iso8859-1" | "iso_8859-1" | "latin1" | "l1") => {
                Ok(self.body.iter().map(|&b| char::from(b)).collect())
            }
            Some(other) => Err(Error::Decode(format!(
                "unsupported Content-Type charset {other:?}; \
                 use Response::body for the raw {} bytes",
                self.body.len()
            ))),
        }
    }
    /// Deserializes the body as JSON.
    #[cfg(feature = "json")]
    pub fn json<T: serde::de::DeserializeOwned>(&self) -> Result<T> {
        serde_json::from_slice(&self.body).map_err(|e| Error::Decode(format!("json: {e}")))
    }
    /// Consumes the response and returns a reader over its body.
    pub fn into_reader(self) -> io::Cursor<Vec<u8>> {
        io::Cursor::new(self.body)
    }
    /// Returns `Error::Status` for 4xx/5xx responses, otherwise the response itself.
    pub fn error_for_status(self) -> Result<Self> {
        if self.status >= 400 {
            Err(Error::Status {
                code: self.status,
                reason: self.reason,
            })
        } else {
            Ok(self)
        }
    }
}
/// Incremental reader over a response body, returned by `Request::send_reader*`.
pub struct BodyReader {
    /// Head of the response being read.
    head: ResponseHead,
    /// Body source.
    inner: BodyInner,
    /// Optional cancellation guard. Fields drop in declaration order, so this drops
    /// after `inner`.
    _cancel: Option<crate::cancel::CancelGuard>,
}
/// Where a [`BodyReader`] pulls bytes from.
enum BodyInner {
    /// Body already held in memory.
    Buffered(io::Cursor<Vec<u8>>),
    /// Body with a declared length, read directly from the connection.
    Length(LengthBody),
    /// Body delimited by the connection, bounded by a `take` limit.
    Eof(io::Take<BufReader<Box<dyn Rw>>>),
}
impl BodyReader {
    /// Status line and headers of the response.
    pub fn head(&self) -> &ResponseHead {
        &self.head
    }
    /// Status code of the response.
    pub fn status(&self) -> u16 {
        self.head().status
    }
    /// Value of the first header named `name` (ASCII case-insensitive).
    pub fn header(&self, name: &str) -> Option<&str> {
        self.head().header(name)
    }
}
impl Read for BodyReader {
    /// Delegates to whichever body source this reader wraps.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let source: &mut dyn Read = match &mut self.inner {
            BodyInner::Buffered(cursor) => cursor,
            BodyInner::Length(sized) => sized,
            BodyInner::Eof(until_close) => until_close,
        };
        source.read(buf)
    }
}
/// Body framed by `Content-Length`: yields exactly `remaining` more bytes
/// from `src` and reports `UnexpectedEof` if the connection ends first.
struct LengthBody {
// Connection reader positioned inside the response body.
src: BufReader<Box<dyn Rw>>,
// Body bytes not yet handed to the caller.
remaining: u64,
}
impl Read for LengthBody {
    /// Reads at most `remaining` bytes of the body into `buf`.
    ///
    /// Returns `Ok(0)` once the declared length has been consumed, or when
    /// `buf` is empty.
    ///
    /// # Errors
    /// `UnexpectedEof` if the connection closes before `remaining` reaches zero;
    /// any error from the underlying reader is propagated.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        // An empty `buf` must yield Ok(0) per the `Read` contract. Without
        // this guard the inner read would return 0 and be misreported as a
        // truncated body.
        if self.remaining == 0 || buf.is_empty() {
            return Ok(0);
        }
        // `min` with `buf.len()` guarantees the value fits in usize.
        let cap = self.remaining.min(buf.len() as u64) as usize;
        let n = self.src.read(&mut buf[..cap])?;
        if n == 0 {
            return Err(io::Error::new(
                io::ErrorKind::UnexpectedEof,
                "response body truncated before Content-Length",
            ));
        }
        self.remaining -= n as u64;
        Ok(n)
    }
}
/// Result of `open_h1_for_reader`: the parsed response head, the connection
/// reader (body not yet consumed), and the cancellation guard registered for
/// the connection, if any.
type OpenedBody = (
Head,
BufReader<Box<dyn Rw>>,
Option<crate::cancel::CancelGuard>,
);
fn open_h1_for_reader(req: &Request, trace: &mut dyn Write) -> Result<OpenedBody> {
let (tcp, guard, _namelookup) = tcp_connect_cancellable(req, trace)?;
let stream: Box<dyn Rw> = if req.url.scheme == "https" {
let opts = tls_opts_from(req, &[])?;
let tls = crate::tls::connect_over_tls(tcp, &req.url.host, opts)?;
write_tls_info(&tls, trace);
Box::new(tls)
} else {
Box::new(tcp)
};
let mut bufrd = BufReader::new(stream);
write_request(bufrd.get_mut(), req, via_plain_http_proxy(req), trace)?;
let head = read_head(&mut bufrd, trace)?;
Ok((head, bufrd, guard))
}
fn build_body_reader(
req: &Request,
head: Head,
mut bufrd: BufReader<Box<dyn Rw>>,
cancel: Option<crate::cancel::CancelGuard>,
) -> Result<BodyReader> {
let rhead = ResponseHead {
status: head.status,
reason: head.reason.clone(),
version: head.version.clone(),
headers: head.headers.clone(),
};
let has_body = !(req.method.eq_ignore_ascii_case("HEAD")
|| (100..200).contains(&head.status)
|| head.status == 204
|| head.status == 304);
if !has_body {
return Ok(BodyReader {
head: rhead,
inner: BodyInner::Buffered(io::Cursor::new(Vec::new())),
_cancel: None,
});
}
let has_te = head
.headers
.iter()
.any(|(k, _)| k.eq_ignore_ascii_case("transfer-encoding"));
let has_cl = head
.headers
.iter()
.any(|(k, _)| k.eq_ignore_ascii_case("content-length"));
if has_te && has_cl {
return Err(Error::BadResponse(
"both Transfer-Encoding and Content-Length present".into(),
));
}
let chunked = head.headers.iter().any(|(k, v)| {
k.eq_ignore_ascii_case("transfer-encoding") && v.eq_ignore_ascii_case("chunked")
});
if chunked {
let body = read_body(
&mut bufrd,
&head.headers,
&head.version,
head.status,
&req.method,
)?;
return Ok(BodyReader {
head: rhead,
inner: BodyInner::Buffered(io::Cursor::new(body)),
_cancel: None,
});
}
match parse_content_length(&head.headers)? {
Some(len) => {
if len > MAX_BODY_BYTES as u64 {
return Err(Error::BadResponse(format!("body too large: {len}")));
}
Ok(BodyReader {
head: rhead,
inner: BodyInner::Length(LengthBody {
src: bufrd,
remaining: len,
}),
_cancel: cancel,
})
}
None => Ok(BodyReader {
head: rhead,
inner: BodyInner::Eof(bufrd.take(MAX_BODY_BYTES as u64)),
_cancel: cancel,
}),
}
}
/// Performs `req` through the regular buffered path, with decompression
/// disabled, and exposes the finished response as a [`BodyReader`] over its
/// raw (wire) body.
fn buffered_body_reader(mut req: Request, trace: &mut dyn Write) -> Result<BodyReader> {
    req.decompress = false;
    let resp = req.send_to(trace, None)?;
    let head = ResponseHead {
        headers: resp.headers.clone(),
        version: resp.version.clone(),
        reason: resp.reason.clone(),
        status: resp.status,
    };
    let inner = BodyInner::Buffered(io::Cursor::new(resp.body));
    Ok(BodyReader {
        head,
        inner,
        _cancel: None,
    })
}
fn redirect_request(req: Request, head: &Head, location: &str) -> Result<Request> {
let mut next_url = crate::url::resolve(&req.url, location)?;
if next_url.scheme != "http" && next_url.scheme != "https" {
return Err(Error::UnsupportedScheme(next_url.scheme.clone()));
}
next_url.set_idn(req.idn)?;
let host_changed = next_url.host != req.url.host
|| next_url.port != req.url.port
|| next_url.scheme != req.url.scheme;
let prev_method = req.method.clone();
let prev_url = url_to_string(&req.url);
let status = head.status;
let mut next = req;
next.url = next_url;
if next.auto_referer {
next.headers
.retain(|(k, _)| !k.eq_ignore_ascii_case("referer"));
next.headers.push(("Referer".to_string(), prev_url));
}
if host_changed && !next.redirect_trusted {
next.headers.retain(|(k, _)| {
!k.eq_ignore_ascii_case("authorization") && !k.eq_ignore_ascii_case("cookie")
});
next.basic_auth = None;
}
let keep_post = if (301..=303).contains(&status) {
next.keep_post[(status - 301) as usize]
} else {
false
};
if (301..=303).contains(&status)
&& !keep_post
&& !prev_method.eq_ignore_ascii_case("GET")
&& !prev_method.eq_ignore_ascii_case("HEAD")
{
next.method = "GET".to_string();
next.body = Vec::new();
next.headers.retain(|(k, _)| {
!k.eq_ignore_ascii_case("content-type")
&& !k.eq_ignore_ascii_case("content-length")
&& !k.eq_ignore_ascii_case("transfer-encoding")
});
}
Ok(next)
}
/// Sends a plain `http://` request.
///
/// HTTP/3-only is rejected (QUIC needs TLS). An HTTP/3 preference is noted in
/// the trace and served over HTTP/1.1. Non-direct connectors go through
/// `send_plain_via_connector`. For requests not using a proxy, a pooled
/// connection is tried first: a stale one triggers a fresh connection, while
/// any other error is returned.
fn send_plain(req: Request, trace: &mut dyn Write) -> Result<Response> {
    if req.http_version_pref == HttpVersionPref::Http3Only {
        return Err(Error::UnsupportedScheme(
            "http/3 requires https://, not http://".into(),
        ));
    }
    if req.http_version_pref == HttpVersionPref::Http3 {
        let _ = writeln!(
            trace,
            "* HTTP/3 requested but URL is http://; using HTTP/1.1 (h3 needs https)"
        );
    }
    if !req.connector.is_direct() {
        return send_plain_via_connector(req, trace);
    }

    let direct = req.proxy.is_none() || proxy_bypassed(&req);
    let pooled = if direct { pool_checkout_plain(&req) } else { None };
    if let Some(bufrd) = pooled {
        let _ = writeln!(trace, "* Reusing existing connection from pool");
        match perform_on_pooled_plain(bufrd, &req, trace) {
            Ok(resp) => return Ok(resp),
            Err(PooledError::Hard(e)) => return Err(e),
            Err(PooledError::Stale(why)) => {
                let _ = writeln!(trace, "* Pooled connection unusable ({why}); reconnecting");
            }
        }
    }
    send_plain_fresh(req, direct, trace)
}
/// Performs `req` over a newly opened TCP connection, recording DNS, connect
/// and pre-transfer timing. Afterwards the connection is pooled or closed by
/// `finalize_plain`.
fn send_plain_fresh(req: Request, may_pool: bool, trace: &mut dyn Write) -> Result<Response> {
    let t0 = std::time::Instant::now();
    // Bound to a name (not `_`) so the guard stays alive until we return.
    let (tcp, _cancel_guard, namelookup) = tcp_connect_cancellable(&req, trace)?;
    let connect = t0.elapsed();

    let mut conn = BufReader::new(tcp);
    let absolute_form = via_plain_http_proxy(&req);
    write_request(conn.get_mut(), &req, absolute_form, trace)?;
    let mut resp = read_response_timed(&mut conn, &req.method, req.decompress, Some(t0), trace)?;

    resp.timing.namelookup = namelookup;
    resp.timing.connect = Some(connect);
    // Plain HTTP has no handshake after TCP, so pre-transfer equals connect.
    resp.timing.pretransfer = Some(connect);
    finalize_plain(conn, &req, &resp, may_pool, trace);
    Ok(resp)
}
/// Opens a stream to the request's host/port through the custom connector,
/// applying `connect_timeout` to the connect. `read_timeout` is applied to
/// both socket read and write timeouts.
fn connector_connect(req: &Request, trace: &mut dyn Write) -> Result<Box<dyn NetStream>> {
    let host = &req.url.host;
    let port = req.url.port;
    let _ = writeln!(
        trace,
        "* Connecting to {}:{} via {:?}",
        host, port, req.connector
    );
    let stream = req.connector.connect(host, port, req.connect_timeout)?;
    let timeout = req.read_timeout;
    stream.set_read_timeout(timeout)?;
    stream.set_write_timeout(timeout)?;
    Ok(stream)
}
/// Performs a plain-HTTP request over a stream obtained from the custom
/// connector. The request line uses origin-form (path only).
fn send_plain_via_connector(req: Request, trace: &mut dyn Write) -> Result<Response> {
    let mut conn = BufReader::new(connector_connect(&req, trace)?);
    write_request(conn.get_mut(), &req, false, trace)?;
    read_response(&mut conn, &req.method, req.decompress, trace)
}
fn send_https_via_connector(req: Request, trace: &mut dyn Write) -> Result<Response> {
match req.http_version_pref {
HttpVersionPref::Http2Only => {
return Err(Error::UnsupportedScheme(
"HTTP/2 (--http2) over a custom connector or SOCKS/HTTPS proxy is not supported"
.into(),
));
}
HttpVersionPref::Http3Only => {
return Err(Error::UnsupportedScheme(
"HTTP/3 (--http3-only) over a custom connector or SOCKS/HTTPS proxy is not supported"
.into(),
));
}
_ => {}
}
let stream = connector_connect(&req, trace)?;
let opts = tls_opts_from(&req, &[])?;
let tls = crate::tls::connect_over_tls(stream, &req.url.host, opts)?;
write_tls_info(&tls, trace);
let mut bufrd = BufReader::new(tls);
write_request(bufrd.get_mut(), &req, false, trace)?;
let resp = read_response(&mut bufrd, &req.method, req.decompress, trace)?;
Ok(resp)
}
fn perform_on_pooled_plain(
mut bufrd: BufReader<TcpStream>,
req: &Request,
trace: &mut dyn Write,
) -> std::result::Result<Response, PooledError> {
if let Err(e) = write_request(bufrd.get_mut(), req, via_plain_http_proxy(req), trace) {
return Err(stale_or_hard(e));
}
let resp = match read_response(&mut bufrd, &req.method, req.decompress, trace) {
Ok(r) => r,
Err(e) => return Err(stale_or_hard(e)),
};
finalize_plain(bufrd, req, &resp, true, trace);
Ok(resp)
}
/// Returns the plain connection to the pool when `may_pool` is set and the
/// response leaves the connection reusable; otherwise drops (closes) it.
fn finalize_plain(
    bufrd: BufReader<TcpStream>,
    req: &Request,
    resp: &Response,
    may_pool: bool,
    trace: &mut dyn Write,
) {
    let keep = may_pool && response_is_reusable(&req.method, resp);
    if !keep {
        let _ = writeln!(trace, "* Connection closed");
        return;
    }
    let key = pool_key_for(req);
    crate::pool::plain()
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner())
        .release(key, bufrd);
    let _ = writeln!(trace, "* Connection kept alive (pooled)");
}
/// Failure while reusing a pooled connection.
enum PooledError {
// The connection looked dead (peer closed/reset/etc.); callers retry on a
// fresh connection. Carries a human-readable reason for the trace.
Stale(String),
// Any other error; callers return it as-is.
Hard(Error),
}
fn stale_or_hard(e: Error) -> PooledError {
match &e {
Error::UnexpectedEof => PooledError::Stale("connection closed by peer".into()),
Error::Io(io_err) => match io_err.kind() {
io::ErrorKind::UnexpectedEof
| io::ErrorKind::ConnectionReset
| io::ErrorKind::ConnectionAborted
| io::ErrorKind::BrokenPipe
| io::ErrorKind::NotConnected => PooledError::Stale(io_err.to_string()),
_ => PooledError::Hard(e),
},
_ => PooledError::Hard(e),
}
}
/// Computes where a connection to `host:port` will actually be dialed, after
/// `--connect-to` rules and then `--resolve` pins are applied.
///
/// Returns `Some((ip, port))` when a resolve pin matches the (possibly
/// remapped) target, `Some((host, port))` when only a connect-to rule changed
/// it, and `None` when the target is the URL's own host and port.
fn effective_dial_target(
    connect_to: &[(String, u16, String, u16)],
    resolve: &[(String, u16, std::net::IpAddr)],
    host: &str,
    port: u16,
) -> Option<(String, u16)> {
    let (dial_host, dial_port) = apply_connect_to(connect_to, host, port);
    let pin = resolve
        .iter()
        .find(|(pinned_host, pinned_port, _)| {
            *pinned_port == dial_port && pinned_host.eq_ignore_ascii_case(&dial_host)
        });
    if let Some((_, _, ip)) = pin {
        return Some((ip.to_string(), dial_port));
    }
    let remapped = dial_host != host || dial_port != port;
    remapped.then_some((dial_host, dial_port))
}
/// Builds the connection-pool key for `req`: URL scheme, host and port, the
/// effective dial target (connect-to / resolve overrides), and the request's
/// partition key.
pub(crate) fn pool_key_for(req: &Request) -> crate::pool::Key {
    let url = &req.url;
    let effective_target =
        effective_dial_target(&req.connect_to, &req.resolve, &url.host, url.port);
    crate::pool::Key {
        effective_target,
        partition: req.partition_key.clone(),
        scheme: url.scheme.clone(),
        host: url.host.clone(),
        port: url.port,
    }
}
/// Takes an idle plain-TCP connection matching `req`'s pool key, if any.
/// A poisoned pool mutex is recovered rather than propagated.
fn pool_checkout_plain(req: &Request) -> Option<BufReader<TcpStream>> {
    let key = pool_key_for(req);
    crate::pool::plain()
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner())
        .checkout(&key)
}
/// Reports whether a TLS connection for `req` may be taken from, or returned
/// to, the shared TLS pool.
///
/// The pool key (`pool_key_for`) contains no TLS settings. Only requests that
/// use the default TLS configuration are therefore eligible: certificate
/// verification on, default trust roots, no client certificate, no key
/// pinning, no CRL, default cipher lists and no protocol-version bounds.
/// Otherwise a request with stricter settings could reuse a session
/// negotiated without them, e.g. a `tls_min` of TLS 1.3 being served by a
/// pooled TLS 1.2 session.
///
/// NOTE(review): `tls_verify_callback` (copied into `TlsOpts` by
/// `tls_opts_from`) is not part of the pool key either; consider excluding
/// requests that set it as well.
pub(crate) fn tls_pool_eligible(req: &Request) -> bool {
    req.verify_tls
        && req.ca_bundle.is_none()
        && req.ca_path.is_none()
        && req.client_cert.is_none()
        && req.pinned_pubkey.is_none()
        && req.crl_file.is_none()
        && req.ciphers.is_none()
        && req.tls13_ciphers.is_none()
        && req.tls_min.is_none()
        && req.tls_max.is_none()
}
/// Takes an idle TLS connection matching `req`'s pool key, but only for
/// requests allowed to use the TLS pool (see `tls_pool_eligible`).
fn pool_checkout_tls(req: &Request) -> Option<BufReader<crate::tls::TlsStream<TcpStream>>> {
    if tls_pool_eligible(req) {
        crate::pool::tls()
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .checkout(&pool_key_for(req))
    } else {
        None
    }
}
/// Decides whether the connection that carried `resp` can serve another
/// request.
///
/// The connection is not reusable if any of these hold:
/// * `Connection: close` is present;
/// * a response that may carry a body lacks `Content-Length` or
///   `Transfer-Encoding: chunked` framing.
///
/// Otherwise HTTP/1.1 responses are reusable, and other versions only with an
/// explicit `Connection: keep-alive`.
fn response_is_reusable(method: &str, resp: &Response) -> bool {
    let connection_has = |token: &str| {
        resp.headers.iter().any(|(k, v)| {
            k.eq_ignore_ascii_case("connection")
                && v.split(',').any(|t| t.trim().eq_ignore_ascii_case(token))
        })
    };
    if connection_has("close") {
        return false;
    }

    let framed = resp.headers.iter().any(|(k, v)| {
        k.eq_ignore_ascii_case("content-length")
            || (k.eq_ignore_ascii_case("transfer-encoding") && v.eq_ignore_ascii_case("chunked"))
    });
    let bodyless =
        method.eq_ignore_ascii_case("HEAD") || matches!(resp.status, 100..=199 | 204 | 304);
    if !(framed || bodyless) {
        return false;
    }

    resp.version == "HTTP/1.1" || connection_has("keep-alive")
}
/// True when `req` is a plain `http://` request forwarded through a
/// configured, non-bypassed proxy, i.e. it must use absolute-form targets.
pub(crate) fn via_plain_http_proxy(req: &Request) -> bool {
    req.url.scheme == "http" && req.proxy.is_some() && !proxy_bypassed(req)
}
/// Checks the request host against the `no_proxy` list.
///
/// A `*` entry bypasses every host. Other entries match, case-insensitively,
/// the host itself or any subdomain of it; leading dots on an entry are
/// ignored and empty entries never match.
pub(crate) fn proxy_bypassed(req: &Request) -> bool {
    let entries = || req.no_proxy.iter().map(|entry| entry.trim());
    if entries().any(|entry| entry == "*") {
        return true;
    }
    let host = req.url.host.to_ascii_lowercase();
    entries().any(|entry| {
        let domain = entry.trim_start_matches('.').to_ascii_lowercase();
        !domain.is_empty()
            && host
                .strip_suffix(domain.as_str())
                .is_some_and(|prefix| prefix.is_empty() || prefix.ends_with('.'))
    })
}
/// The method as written on the wire: upper-cased unless `keep_method_case`
/// asks for it to be sent verbatim.
pub(crate) fn effective_method(req: &Request) -> String {
    if !req.keep_method_case {
        return req.method.to_ascii_uppercase();
    }
    req.method.clone()
}
/// Returns the base64 `user:pass` credentials for an `Authorization: Basic`
/// header, or `None` when both parts are empty.
///
/// Explicit `basic_auth` wins. Otherwise the URL's userinfo is split at the
/// first `:`, and a missing `:` means an empty password.
///
/// NOTE(review): userinfo is used as stored in `Url`; if it is kept
/// percent-encoded, credentials such as `user%40host` are sent encoded —
/// confirm against `crate::url`.
pub(crate) fn effective_basic_auth(req: &Request) -> Option<String> {
    let (user, pass): (&str, &str) = match &req.basic_auth {
        Some((u, p)) => (u.as_str(), p.as_str()),
        None => {
            let info = req.url.userinfo.as_deref()?;
            info.split_once(':').unwrap_or((info, ""))
        }
    };
    if user.is_empty() && pass.is_empty() {
        return None;
    }
    let combined = format!("{user}:{pass}");
    Some(crate::websocket::base64_encode(combined.as_bytes()))
}
pub(crate) fn tls_opts_from(req: &Request, alpn: &[&[u8]]) -> Result<crate::tls::TlsOpts> {
let mut opts = crate::tls::TlsOpts::verifying();
opts.alpn = alpn.iter().map(|p| p.to_vec()).collect();
opts.verify = req.verify_tls;
opts.min_version = req.tls_min;
opts.max_version = req.tls_max;
if let Some(path) = &req.ca_bundle {
opts.roots = Some(crate::tls::load_roots_from_file(path)?);
}
if let Some(dir) = &req.ca_path {
opts.roots = Some(crate::tls::load_roots_from_dir(opts.roots.take(), dir)?);
}
if let Some(cert_path) = &req.client_cert {
opts.client_cert = Some(std::fs::read(cert_path).map_err(Error::Io)?);
opts.cert_is_der = req.cert_is_der;
opts.key_is_der = req.key_is_der;
opts.client_key_pass = req.client_key_pass.clone();
if let Some(key_path) = &req.client_key {
opts.client_key = Some(std::fs::read(key_path).map_err(Error::Io)?);
}
}
if let Some(spec) = &req.pinned_pubkey {
opts.pinned_spki_sha256 = crate::tls::parse_pinned_pubkey(spec)?;
}
if let Some(path) = &req.crl_file {
opts.crl_pem = Some(std::fs::read(path).map_err(Error::Io)?);
}
if let Some(spec) = &req.ciphers {
opts.cipher_suites
.extend(crate::tls::cipher_names_to_ids(spec)?);
}
if let Some(spec) = &req.tls13_ciphers {
opts.cipher_suites
.extend(crate::tls::cipher_names_to_ids(spec)?);
}
opts.verify_callback = req.tls_verify_callback.clone();
Ok(opts)
}
pub(crate) fn h3_should_fall_back(e: &Error) -> bool {
match e {
Error::Io(_) => true,
Error::UnsupportedScheme(_) | Error::InvalidUrl(_) => true,
Error::BadResponse(m) => {
m.starts_with("http3: connection closed")
|| m.starts_with("http3: peer closed")
|| m.starts_with("http3: build client")
|| m.starts_with("http3: open_bidi")
|| m.starts_with("http3: open_uni")
|| m.starts_with("http3: feed")
|| m.starts_with("http3: stream read")
|| m.starts_with("http3: stream write")
|| m.starts_with("http3: stream finish")
}
_ => false,
}
}
pub fn send_multiplexed(reqs: Vec<Request>) -> Vec<Result<Response>> {
send_multiplexed_traced(reqs, &mut std::io::sink())
}
pub fn send_multiplexed_traced(
mut reqs: Vec<Request>,
trace: &mut dyn Write,
) -> Vec<Result<Response>> {
for req in &mut reqs {
let _ = req.url.set_idn(req.idn);
}
let mut idx: Vec<usize> = (0..reqs.len()).collect();
idx.sort_by_key(|&i| reqs[i].priority.rank());
let already_in_order = idx.iter().enumerate().all(|(k, &i)| k == i);
if already_in_order {
return crate::http2::send_multiplexed(reqs, trace);
}
let mut slots: Vec<Option<Request>> = reqs.into_iter().map(Some).collect();
let ordered: Vec<Request> = idx.iter().map(|&i| slots[i].take().unwrap()).collect();
let results = crate::http2::send_multiplexed(ordered, trace);
let mut out: Vec<Option<Result<Response>>> = (0..idx.len()).map(|_| None).collect();
for (k, r) in results.into_iter().enumerate() {
out[idx[k]] = Some(r);
}
out.into_iter()
.map(|o| o.expect("each slot filled"))
.collect()
}
/// Sends an `https://` request, negotiating the protocol per
/// `http_version_pref`.
///
/// Stages:
/// 1. QUIC: HTTP/3-only sends over HTTP/3 and returns its result. An HTTP/3
///    preference tries HTTP/3 first and falls back when `h3_should_fall_back`
///    allows it.
/// 2. ALPN: HTTP/2-only goes straight to HTTP/2. Auto/HTTP/3 try h2 and fall
///    back to HTTP/1.1 when the server does not select it. HTTP/1.1-only
///    skips this stage.
/// 3. HTTP/1.1: direct requests try a pooled TLS connection first (stale ones
///    are replaced), then a fresh connection.
///
/// Non-direct connectors are handled by `send_https_via_connector`.
fn send_https(req: Request, trace: &mut dyn Write) -> Result<Response> {
    if !req.connector.is_direct() {
        return send_https_via_connector(req, trace);
    }
    let pref = req.http_version_pref;

    if pref == HttpVersionPref::Http3Only {
        let _ = writeln!(trace, "* Trying HTTP/3 (QUIC), required (--http3-only)");
        return crate::http3::send(req, trace);
    }
    if pref == HttpVersionPref::Http3 {
        let _ = writeln!(trace, "* Trying HTTP/3 (QUIC)...");
        match crate::http3::send(req.clone(), trace) {
            Ok(resp) => return Ok(resp),
            Err(e) if !h3_should_fall_back(&e) => return Err(e),
            Err(e) => {
                let _ = writeln!(trace, "* HTTP/3 failed ({e}), falling back to HTTP/2/1.1");
            }
        }
    }

    match pref {
        HttpVersionPref::Http2Only => {
            let _ = writeln!(trace, "* HTTP/2 required (--http2)");
            return crate::http2::send(req, trace);
        }
        HttpVersionPref::Http11Only => {
            let _ = writeln!(trace, "* HTTP/1.1 forced (--http1.1)");
        }
        HttpVersionPref::Auto | HttpVersionPref::Http3 => {
            let _ = writeln!(trace, "* Trying HTTP/2 via ALPN (h2)");
            match crate::http2::send(req.clone(), trace) {
                Ok(resp) => return Ok(resp),
                Err(Error::H2NotNegotiated) => {
                    let _ = writeln!(
                        trace,
                        "* ALPN: server did not select h2, falling back to HTTP/1.1"
                    );
                }
                Err(e) => return Err(e),
            }
        }
        HttpVersionPref::Http3Only => unreachable!("Http3Only handled above"),
    }

    let direct = req.proxy.is_none() || proxy_bypassed(&req);
    let pooled = if direct { pool_checkout_tls(&req) } else { None };
    if let Some(bufrd) = pooled {
        let _ = writeln!(trace, "* Reusing existing connection from pool");
        match perform_on_pooled_tls(bufrd, &req, trace) {
            Ok(resp) => return Ok(resp),
            Err(PooledError::Hard(e)) => return Err(e),
            Err(PooledError::Stale(why)) => {
                let _ = writeln!(trace, "* Pooled connection unusable ({why}); reconnecting");
            }
        }
    }
    send_https_fresh(req, direct, trace)
}
/// Performs `req` as HTTP/1.1 over a new TCP+TLS connection.
///
/// When an HTTPS target goes through a non-bypassed proxy, a CONNECT tunnel is
/// opened first. Timing (DNS, connect, TLS handshake, pre-transfer) and TLS
/// session details are recorded on the response, and the connection is then
/// pooled or closed by `finalize_tls`.
fn send_https_fresh(req: Request, may_pool: bool, trace: &mut dyn Write) -> Result<Response> {
    let t0 = std::time::Instant::now();
    // Bound to a name (not `_`) so the guard stays alive until we return.
    let (tcp, _cancel_guard, namelookup) = tcp_connect_cancellable(&req, trace)?;
    let connect = t0.elapsed();

    let tunnel_proxy = req
        .proxy
        .as_ref()
        .filter(|_| !proxy_bypassed(&req) && req.url.scheme == "https");
    if let Some(proxy) = tunnel_proxy {
        connect_tunnel(&tcp, &req.url, proxy, trace)?;
    }

    let opts = tls_opts_from(&req, &[])?;
    let session = crate::tls::connect_over_tls(tcp, &req.url.host, opts)?;
    let appconnect = t0.elapsed();
    write_tls_info(&session, trace);
    let tls_info = tls_info_from(&session);

    let mut conn = BufReader::new(session);
    write_request(conn.get_mut(), &req, false, trace)?;
    let mut resp = read_response_timed(&mut conn, &req.method, req.decompress, Some(t0), trace)?;

    resp.timing.namelookup = namelookup;
    resp.timing.connect = Some(connect);
    resp.timing.appconnect = Some(appconnect);
    resp.timing.pretransfer = Some(appconnect);
    resp.tls = Some(tls_info);
    finalize_tls(conn, &req, &resp, may_pool, trace);
    Ok(resp)
}
fn perform_on_pooled_tls(
mut bufrd: BufReader<crate::tls::TlsStream<TcpStream>>,
req: &Request,
trace: &mut dyn Write,
) -> std::result::Result<Response, PooledError> {
if let Err(e) = write_request(bufrd.get_mut(), req, false, trace) {
return Err(stale_or_hard(e));
}
let mut resp = match read_response(&mut bufrd, &req.method, req.decompress, trace) {
Ok(r) => r,
Err(e) => return Err(stale_or_hard(e)),
};
resp.tls = Some(tls_info_from(bufrd.get_ref()));
finalize_tls(bufrd, req, &resp, true, trace);
Ok(resp)
}
/// Returns the TLS connection to the pool when pooling is allowed
/// (`may_pool`, `tls_pool_eligible`) and the response leaves the connection
/// reusable; otherwise drops (closes) it.
fn finalize_tls(
    bufrd: BufReader<crate::tls::TlsStream<TcpStream>>,
    req: &Request,
    resp: &Response,
    may_pool: bool,
    trace: &mut dyn Write,
) {
    let keep = may_pool && tls_pool_eligible(req) && response_is_reusable(&req.method, resp);
    if !keep {
        let _ = writeln!(trace, "* Connection closed");
        return;
    }
    let key = pool_key_for(req);
    crate::pool::tls()
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner())
        .release(key, bufrd);
    let _ = writeln!(trace, "* Connection kept alive (pooled)");
}
/// Applies the first matching `--connect-to` rule to `host:port`.
///
/// Each rule is `(from_host, from_port, to_host, to_port)`. An empty host or
/// a zero port acts as a wildcard on the `from` side and means "keep the
/// original" on the `to` side. Host comparison ignores ASCII case. When no
/// rule matches, the input is returned unchanged.
fn apply_connect_to(rules: &[(String, u16, String, u16)], host: &str, port: u16) -> (String, u16) {
    let matching = rules.iter().find(|(from_host, from_port, _, _)| {
        (from_host.is_empty() || from_host.eq_ignore_ascii_case(host))
            && (*from_port == 0 || *from_port == port)
    });
    match matching {
        Some((_, _, to_host, to_port)) => {
            let new_host = if to_host.is_empty() { host } else { to_host.as_str() };
            let new_port = if *to_port == 0 { port } else { *to_port };
            (new_host.to_string(), new_port)
        }
        None => (host.to_string(), port),
    }
}
fn resolve_target(
host: &str,
port: u16,
req: &Request,
) -> Result<(std::net::SocketAddr, Option<Duration>)> {
if let Some((_, _, ip)) = req
.resolve
.iter()
.find(|(h, p, _)| *p == port && h.eq_ignore_ascii_case(host))
{
return Ok((std::net::SocketAddr::new(*ip, port), None));
}
let started = std::time::Instant::now();
let addrs = req.resolver.resolve(host, port)?;
let dns = started.elapsed();
let chosen = match req.ip_family {
Some(IpFamily::V4) => addrs
.into_iter()
.find(|a| a.is_ipv4())
.ok_or_else(|| Error::InvalidUrl(format!("{host}: no IPv4 address"))),
Some(IpFamily::V6) => addrs
.into_iter()
.find(|a| a.is_ipv6())
.ok_or_else(|| Error::InvalidUrl(format!("{host}: no IPv6 address"))),
None => addrs
.into_iter()
.next()
.ok_or_else(|| Error::InvalidUrl(host.to_string())),
}?;
Ok((chosen, Some(dns)))
}
/// Opens the TCP connection for `req` and, when the request has a cancel
/// token, registers a callback that shuts the socket down on cancellation.
///
/// Returns the stream, the registration guard (`None` when there is no token
/// or the socket could not be cloned), and the DNS lookup time if one was
/// performed.
pub(crate) fn tcp_connect_cancellable(
    req: &Request,
    trace: &mut dyn Write,
) -> Result<(
    TcpStream,
    Option<crate::cancel::CancelGuard>,
    Option<Duration>,
)> {
    let (stream, namelookup) = tcp_connect_inner(req, trace)?;
    let guard = req.cancel.as_ref().and_then(|token| {
        // The callback owns a second handle to the same socket.
        let shut = stream.try_clone().ok()?;
        Some(token.register(Box::new(move || {
            let _ = shut.shutdown(std::net::Shutdown::Both);
        })))
    });
    Ok((stream, guard, namelookup))
}
/// Dials the TCP connection for `req`.
///
/// The target is the proxy when one applies. Otherwise it is the URL's
/// host/port after `--connect-to` rules, which are not applied to proxies.
/// The connect honours `connect_timeout`, and `read_timeout` is applied to
/// both socket read and write timeouts.
fn tcp_connect_inner(
    req: &Request,
    trace: &mut dyn Write,
) -> Result<(TcpStream, Option<Duration>)> {
    let proxy = req.proxy.as_ref().filter(|_| !proxy_bypassed(req));
    let (dial_host, dial_port) = match proxy {
        Some(p) => (p.host.to_string(), p.port),
        None => apply_connect_to(&req.connect_to, &req.url.host, req.url.port),
    };

    let (first, namelookup) = resolve_target(&dial_host, dial_port, req)?;
    let _ = writeln!(trace, "* Trying {first}...");
    let stream = if let Some(limit) = req.connect_timeout {
        TcpStream::connect_timeout(&first, limit)?
    } else {
        TcpStream::connect(first)?
    };

    let peer = stream.peer_addr().unwrap_or(first);
    let _ = if proxy.is_some() {
        writeln!(
            trace,
            "* Connected to proxy {} ({}) port {}",
            dial_host,
            peer.ip(),
            peer.port()
        )
    } else {
        writeln!(
            trace,
            "* Connected to {} ({}) port {}",
            req.url.host,
            peer.ip(),
            peer.port()
        )
    };

    stream.set_read_timeout(req.read_timeout)?;
    stream.set_write_timeout(req.read_timeout)?;
    Ok((stream, namelookup))
}
pub(crate) fn connect_tunnel<S: Read + Write>(
mut stream: S,
target: &Url,
proxy: &ProxyConfig,
trace: &mut dyn Write,
) -> Result<()> {
let host_port = format!("{}:{}", target.host, target.port);
let mut buf = Vec::with_capacity(256);
write!(&mut buf, "CONNECT {host_port} HTTP/1.1\r\n")?;
write!(&mut buf, "Host: {host_port}\r\n")?;
write!(&mut buf, "User-Agent: {DEFAULT_USER_AGENT}\r\n")?;
write!(&mut buf, "Proxy-Connection: Keep-Alive\r\n")?;
if let Some((user, pass)) = &proxy.auth {
let combined = format!("{user}:{pass}");
let creds = crate::websocket::base64_encode(combined.as_bytes());
write!(&mut buf, "Proxy-Authorization: Basic {creds}\r\n")?;
}
write!(&mut buf, "\r\n")?;
let head = String::from_utf8_lossy(&buf);
let head_no_final_crlf = head.strip_suffix("\r\n").unwrap_or(&head);
for line in head_no_final_crlf.split("\r\n") {
let _ = writeln!(trace, "> {line}");
}
stream.write_all(&buf)?;
stream.flush()?;
let mut line_buf: Vec<u8> = Vec::with_capacity(128);
let mut byte = [0u8; 1];
let mut status_line: Option<String> = None;
let mut total = 0usize;
loop {
if total > MAX_HEADER_BYTES {
return Err(Error::BadResponse(
"CONNECT response headers exceed 64 KiB".into(),
));
}
let n = stream.read(&mut byte)?;
if n == 0 {
return Err(Error::UnexpectedEof);
}
total += 1;
line_buf.push(byte[0]);
if byte[0] == b'\n' {
let trimmed_owned = String::from_utf8_lossy(
line_buf
.strip_suffix(b"\n")
.unwrap_or(&line_buf)
.strip_suffix(b"\r")
.unwrap_or(line_buf.strip_suffix(b"\n").unwrap_or(&line_buf)),
)
.into_owned();
let _ = writeln!(trace, "< {trimmed_owned}");
if status_line.is_none() {
status_line = Some(trimmed_owned.clone());
}
if trimmed_owned.is_empty() {
break;
}
line_buf.clear();
}
}
let status = status_line.ok_or_else(|| Error::BadResponse("CONNECT: no status line".into()))?;
let parts: Vec<&str> = status.splitn(3, ' ').collect();
if parts.len() < 2 {
return Err(Error::BadResponse(format!(
"CONNECT: malformed status line {status:?}"
)));
}
let code: u16 = parts[1]
.parse()
.map_err(|_| Error::BadResponse(format!("CONNECT: bad status {:?}", parts[1])))?;
if !(200..300).contains(&code) {
return Err(Error::BadResponse(format!(
"CONNECT to {host_port} failed: {status}"
)));
}
let _ = writeln!(trace, "* CONNECT tunnel established to {host_port}");
Ok(())
}
/// Snapshots the negotiated parameters of a TLS session: protocol version,
/// cipher suite, selected ALPN protocol, and the peer certificate chain.
pub(crate) fn tls_info_from<S: Read + Write>(tls: &crate::tls::TlsStream<S>) -> TlsInfo {
    let alpn = tls.alpn_selected().map(|proto| proto.to_vec());
    let peer_certificates = tls.peer_certificates().to_vec();
    TlsInfo {
        version: tls.negotiated_version(),
        cipher_suite: tls.negotiated_cipher_suite(),
        alpn,
        peer_certificates,
    }
}
/// Writes a human-readable summary of a TLS session to `trace`.
///
/// The summary covers the negotiated version, the ALPN result, and, for each
/// peer certificate, its subject/issuer CN and validity period when the DER
/// parses. Trace write errors are ignored.
pub(crate) fn write_tls_info<S: Read + Write>(
    tls: &crate::tls::TlsStream<S>,
    trace: &mut dyn Write,
) {
    if let Some(v) = tls.negotiated_version() {
        let _ = writeln!(trace, "* SSL connection using {v:?}");
    }
    let _ = match tls.alpn_selected() {
        Some(p) => writeln!(
            trace,
            "* ALPN: server accepted {}",
            String::from_utf8_lossy(p)
        ),
        None => writeln!(trace, "* ALPN: no protocol negotiated"),
    };

    let chain = tls.peer_certificates();
    let _ = writeln!(trace, "* Server certificate chain: {} cert(s)", chain.len());
    for (i, der) in chain.iter().enumerate() {
        let Ok(cert) = purecrypto::x509::Certificate::from_der(der.clone()) else {
            let _ = writeln!(trace, "* [{i}] (DER unparseable, {} bytes)", der.len());
            continue;
        };
        let subject = cert
            .subject()
            .ok()
            .and_then(|name| name.common_name)
            .unwrap_or_else(|| "?".into());
        let issuer = cert
            .issuer()
            .ok()
            .and_then(|name| name.common_name)
            .unwrap_or_else(|| "?".into());
        let _ = writeln!(trace, "* [{i}] subject CN: {subject}");
        let _ = writeln!(trace, "* issuer CN: {issuer}");
        if let Ok(v) = cert.validity() {
            let _ = writeln!(
                trace,
                "* valid: {} -> {}",
                v.not_before.as_str(),
                v.not_after.as_str()
            );
        }
    }
}
/// True when `name` is a non-empty HTTP token: ASCII alphanumerics plus the
/// `tchar` punctuation set ``!#$%&'*+-.^_`|~``.
fn is_valid_header_name(name: &str) -> bool {
    const TCHAR_PUNCTUATION: &[u8] = b"!#$%&'*+-.^_`|~";
    !name.is_empty()
        && name
            .bytes()
            .all(|b| b.is_ascii_alphanumeric() || TCHAR_PUNCTUATION.contains(&b))
}
/// True when `value` contains CR, LF or NUL, any of which would allow header
/// injection or break request framing.
fn header_value_has_forbidden(value: &str) -> bool {
    value.bytes().any(|b| matches!(b, b'\r' | b'\n' | b'\0'))
}
fn validate_header(name: &str, value: &str) -> Result<()> {
if !is_valid_header_name(name) {
return Err(Error::BadResponse(format!("invalid header name: {name:?}")));
}
if header_value_has_forbidden(value) {
return Err(Error::BadResponse(format!(
"invalid header value for {name:?}"
)));
}
Ok(())
}
fn validate_method(method: &str) -> Result<()> {
if method.is_empty() || method.bytes().any(|b| b < 0x20 || b == 0x7f || b == b' ') {
return Err(Error::BadResponse(format!("invalid method: {method:?}")));
}
Ok(())
}
/// Serializes the HTTP/1.1 request for `req` (head, then body if any) to `w`,
/// echoing every head line to `trace` prefixed with `> `.
///
/// `absolute_form` writes the request target as an absolute URI (for a
/// plain-HTTP proxy) and adds `Proxy-Authorization` when the proxy has
/// credentials.
///
/// The automatic `Host` header is omitted whenever the caller supplied its
/// own `Host`. Sending both would put two `Host` fields on the wire, which an
/// HTTP/1.1 server must reject with 400.
///
/// Unless `strict_headers` is set, these defaults are added when the caller
/// did not set them: `Authorization` (basic auth), `User-Agent`, `Accept` and
/// `Accept-Encoding`.
///
/// `Content-Length` is added for a non-empty body unless the caller already
/// set `Content-Length` or `Transfer-Encoding`; a message must not carry both.
///
/// # Errors
/// `Error::BadResponse` for an invalid method or header, and I/O errors from
/// writing to `w`.
fn write_request<W: Write>(
    mut w: W,
    req: &Request,
    absolute_form: bool,
    trace: &mut dyn Write,
) -> Result<()> {
    let method = effective_method(req);
    validate_method(&method)?;
    for (k, v) in &req.headers {
        validate_header(k, v)?;
    }

    let default_port = (req.url.scheme == "http" && req.url.port == 80)
        || (req.url.scheme == "https" && req.url.port == 443);
    let host_header = if default_port {
        req.url.host.clone()
    } else {
        format!("{}:{}", req.url.host, req.url.port)
    };
    let caller_has = |name: &str| {
        req.headers
            .iter()
            .any(|(k, _)| k.eq_ignore_ascii_case(name))
    };

    let mut buf = Vec::with_capacity(256);
    if absolute_form {
        write!(
            &mut buf,
            "{method} {}://{host_header}{} HTTP/1.1\r\n",
            req.url.scheme, req.url.path
        )?;
    } else {
        write!(&mut buf, "{method} {} HTTP/1.1\r\n", req.url.path)?;
    }
    // A caller-supplied Host is emitted with the other caller headers below.
    if !caller_has("host") {
        write!(&mut buf, "Host: {host_header}\r\n")?;
    }
    if absolute_form {
        if let Some(p) = &req.proxy {
            if let Some((user, pass)) = &p.auth {
                let combined = format!("{user}:{pass}");
                let creds = crate::websocket::base64_encode(combined.as_bytes());
                write!(&mut buf, "Proxy-Authorization: Basic {creds}\r\n")?;
            }
        }
    }
    for (k, v) in &req.headers {
        write!(&mut buf, "{k}: {v}\r\n")?;
    }
    if !req.strict_headers {
        if !caller_has("authorization") {
            if let Some(creds) = effective_basic_auth(req) {
                write!(&mut buf, "Authorization: Basic {creds}\r\n")?;
            }
        }
        if !caller_has("user-agent") {
            write!(&mut buf, "User-Agent: {DEFAULT_USER_AGENT}\r\n")?;
        }
        if !caller_has("accept") {
            write!(&mut buf, "Accept: */*\r\n")?;
        }
        // NOTE(review): advertised even when `req.decompress` is false, in
        // which case a compressed body is returned undecoded — confirm intent.
        if !caller_has("accept-encoding") {
            write!(&mut buf, "Accept-Encoding: gzip, deflate\r\n")?;
        }
    }
    if !req.body.is_empty() && !caller_has("content-length") && !caller_has("transfer-encoding")
    {
        write!(&mut buf, "Content-Length: {}\r\n", req.body.len())?;
    }
    write!(&mut buf, "\r\n")?;

    let head = String::from_utf8_lossy(&buf);
    let head_no_final_crlf = head.strip_suffix("\r\n").unwrap_or(&head);
    for line in head_no_final_crlf.split("\r\n") {
        let _ = writeln!(trace, "> {line}");
    }
    w.write_all(&buf)?;
    if !req.body.is_empty() {
        let _ = writeln!(trace, "* uploading {} body bytes", req.body.len());
        w.write_all(&req.body)?;
    }
    w.flush()?;
    Ok(())
}
/// Reads one `\n`-terminated line from `r`, appending it (lossily decoded as
/// UTF-8, terminator included) to `buf`.
///
/// Returns the number of raw bytes read, or 0 at EOF. A final line without a
/// terminator is returned as-is.
///
/// # Errors
/// `Error::BadResponse` once the line grows past `max` bytes without a
/// newline; I/O errors are propagated.
fn read_line_capped<R: BufRead>(r: &mut R, buf: &mut String, max: usize) -> Result<usize> {
    let mut line: Vec<u8> = Vec::new();
    loop {
        // One byte of headroom past `max` makes an over-long line detectable.
        let budget = max.saturating_sub(line.len()) as u64 + 1;
        let got = r.by_ref().take(budget).read_until(b'\n', &mut line)?;
        if got == 0 || line.ends_with(b"\n") {
            break;
        }
        if line.len() > max {
            return Err(Error::BadResponse("response line exceeds 64 KiB".into()));
        }
    }
    if line.is_empty() {
        return Ok(0);
    }
    buf.push_str(&String::from_utf8_lossy(&line));
    Ok(line.len())
}
/// Parsed HTTP/1.x response head, as produced by `read_head`.
struct Head {
// Protocol token from the status line (starts with "HTTP/").
version: String,
// Numeric status code.
status: u16,
// Reason phrase; empty when the status line has none.
reason: String,
// Header fields in received order, name and value trimmed, name case kept.
headers: Vec<(String, String)>,
}
/// Reads the final HTTP/1.x response head (status line plus header fields)
/// from `r`, echoing each line to `trace` prefixed with `< `.
///
/// Interim `1xx` responses other than `101 Switching Protocols` (e.g.
/// `100 Continue`, `103 Early Hints`) have no body and are read and discarded.
/// A client must accept them before the final response even if it did not
/// request them. Returning one as the final response would leave the real
/// response unread on the connection.
///
/// The `MAX_HEADER_BYTES` budget for header lines is shared across the
/// interim and final heads, so a stream of `1xx` responses is bounded.
///
/// # Errors
/// * `Error::UnexpectedEof` if the stream ends before a head is complete;
/// * `Error::BadResponse` for malformed lines or when headers exceed the
///   budget.
fn read_head<R: Read>(r: &mut BufReader<R>, trace: &mut dyn Write) -> Result<Head> {
    let mut header_bytes = 0usize;
    loop {
        let head = read_one_head(r, trace, &mut header_bytes)?;
        if (100..200).contains(&head.status) && head.status != 101 {
            let _ = writeln!(trace, "* Ignoring interim {} response", head.status);
            continue;
        }
        return Ok(head);
    }
}

/// Reads one status line and its header block.
///
/// Header-line bytes are added to `header_bytes`; fails once that total
/// exceeds `MAX_HEADER_BYTES`.
fn read_one_head<R: Read>(
    r: &mut BufReader<R>,
    trace: &mut dyn Write,
    header_bytes: &mut usize,
) -> Result<Head> {
    let mut status_line = String::new();
    let n = read_line_capped(r, &mut status_line, MAX_HEADER_BYTES)?;
    if n == 0 {
        return Err(Error::UnexpectedEof);
    }
    let trimmed_status = status_line.trim_end_matches(['\r', '\n']);
    let _ = writeln!(trace, "< {trimmed_status}");
    let (version, status, reason) = parse_status_line(trimmed_status)?;
    let mut headers: Vec<(String, String)> = Vec::new();
    loop {
        let mut line = String::new();
        let n = read_line_capped(r, &mut line, MAX_HEADER_BYTES)?;
        if n == 0 {
            return Err(Error::UnexpectedEof);
        }
        *header_bytes += n;
        if *header_bytes > MAX_HEADER_BYTES {
            return Err(Error::BadResponse("headers exceed 64 KiB".into()));
        }
        let trimmed = line.trim_end_matches(['\r', '\n']);
        let _ = writeln!(trace, "< {trimmed}");
        if trimmed.is_empty() {
            break;
        }
        let (k, v) = trimmed
            .split_once(':')
            .ok_or_else(|| Error::BadResponse(format!("malformed header line: {trimmed:?}")))?;
        headers.push((k.trim().to_string(), v.trim().to_string()));
    }
    Ok(Head {
        version,
        status,
        reason,
        headers,
    })
}
/// Reads a complete response (head, body, optional decompression) from `r`
/// without recording time-to-first-byte; delegates to `read_response_timed`.
fn read_response<R: Read>(
r: &mut BufReader<R>,
method: &str,
decompress: bool,
trace: &mut dyn Write,
) -> Result<Response>
where
BufReader<R>: TruncationAware,
{
read_response_timed(r, method, decompress, None, trace)
}
/// Reads a complete response from `r`: the head, then the body via
/// `read_body`, then optional content decoding via `maybe_decode_body`.
///
/// When `start` is given, the time elapsed once the head has been read is
/// recorded as `timing.starttransfer`. `final_url` is left empty and `tls`
/// unset for the caller to fill in.
fn read_response_timed<R: Read>(
    r: &mut BufReader<R>,
    method: &str,
    decompress: bool,
    start: Option<std::time::Instant>,
    trace: &mut dyn Write,
) -> Result<Response>
where
    BufReader<R>: TruncationAware,
{
    let head = read_head(r, trace)?;
    let starttransfer = start.map(|t0| t0.elapsed());
    let raw = read_body(r, &head.headers, &head.version, head.status, method)?;
    let wire_len = raw.len();
    let _ = writeln!(trace, "* Received {wire_len} body bytes");
    let (headers, body) = maybe_decode_body(head.headers, raw, decompress, trace)?;
    Ok(Response {
        status: head.status,
        reason: head.reason,
        version: head.version,
        headers,
        body,
        timing: Timing {
            starttransfer,
            ..Default::default()
        },
        final_url: String::new(),
        tls: None,
    })
}
/// Header list plus body bytes, as returned by [`maybe_decode_body`].
pub(crate) type HeadersAndBody = (Vec<(String, String)>, Vec<u8>);

/// Decodes `body` according to the first `Content-Encoding` header, when `decompress` is set.
///
/// If `decompress` is false or no `Content-Encoding` header exists, `headers` and
/// `body` are returned unchanged. Otherwise the body is passed to
/// `crate::compress::decode_body`:
/// - If it reports the body as decoded, the size change is logged to `trace` and the
///   headers are passed through `crate::compress::strip_after_decode`.
/// - If not, the original headers are kept alongside whatever body it returned.
///
/// # Errors
/// Propagates any error from `crate::compress::decode_body`.
pub(crate) fn maybe_decode_body(
    headers: Vec<(String, String)>,
    body: Vec<u8>,
    decompress: bool,
    trace: &mut dyn Write,
) -> Result<HeadersAndBody> {
    let encoding = if decompress {
        headers.iter().find_map(|(name, value)| {
            name.eq_ignore_ascii_case("content-encoding")
                .then(|| value.clone())
        })
    } else {
        None
    };
    let Some(encoding) = encoding else {
        return Ok((headers, body));
    };

    let wire_len = body.len();
    let decoded = crate::compress::decode_body(body, &encoding)?;
    if !decoded.decoded {
        return Ok((headers, decoded.body));
    }
    let _ = writeln!(
        trace,
        "* Decompressed body: {} -> {} bytes ({})",
        wire_len,
        decoded.body.len(),
        encoding
    );
    Ok((crate::compress::strip_after_decode(headers), decoded.body))
}
pub(crate) fn parse_status_line(line: &str) -> Result<(String, u16, String)> {
let mut parts = line.splitn(3, ' ');
let version = parts
.next()
.ok_or_else(|| Error::BadResponse(format!("missing version: {line:?}")))?
.to_string();
if !version.starts_with("HTTP/") {
return Err(Error::BadResponse(format!("not HTTP: {version}")));
}
let status: u16 = parts
.next()
.ok_or_else(|| Error::BadResponse(format!("missing status: {line:?}")))?
.parse()
.map_err(|_| Error::BadResponse(format!("bad status: {line:?}")))?;
let reason = parts.next().unwrap_or("").to_string();
Ok((version, status, reason))
}
pub(crate) fn parse_content_length(headers: &[(String, String)]) -> Result<Option<u64>> {
let mut seen: Option<u64> = None;
for (k, v) in headers {
if !k.eq_ignore_ascii_case("content-length") {
continue;
}
for part in v.split(',') {
let t = part.trim();
if t.is_empty() || !t.bytes().all(|b| b.is_ascii_digit()) {
return Err(Error::BadResponse(format!("bad Content-Length: {v:?}")));
}
let n: u64 = t
.parse()
.map_err(|_| Error::BadResponse(format!("bad Content-Length: {v:?}")))?;
match seen {
Some(prev) if prev != n => {
return Err(Error::BadResponse(
"conflicting Content-Length values".into(),
));
}
_ => seen = Some(n),
}
}
}
Ok(seen)
}
fn read_body<R: BufRead + TruncationAware>(
r: &mut R,
headers: &[(String, String)],
_version: &str,
status: u16,
method: &str,
) -> Result<Vec<u8>> {
if method.eq_ignore_ascii_case("HEAD")
|| (100..200).contains(&status)
|| status == 204
|| status == 304
{
return Ok(Vec::new());
}
let has_te = headers
.iter()
.any(|(k, _)| k.eq_ignore_ascii_case("transfer-encoding"));
let has_cl = headers
.iter()
.any(|(k, _)| k.eq_ignore_ascii_case("content-length"));
if has_te && has_cl {
return Err(Error::BadResponse(
"both Transfer-Encoding and Content-Length present".into(),
));
}
let chunked = headers.iter().any(|(k, v)| {
k.eq_ignore_ascii_case("transfer-encoding") && v.eq_ignore_ascii_case("chunked")
});
if chunked {
return read_chunked(r);
}
let content_length = parse_content_length(headers)?;
let mut body = Vec::new();
match content_length {
Some(len) => {
if len > MAX_BODY_BYTES as u64 {
return Err(Error::BadResponse(format!("body too large: {len}")));
}
body.reserve(len.min(64 * 1024) as usize);
r.take(len).read_to_end(&mut body)?;
if (body.len() as u64) < len {
return Err(Error::UnexpectedEof);
}
}
None => {
r.take(MAX_BODY_BYTES as u64).read_to_end(&mut body)?;
if r.response_truncated() {
return Err(Error::UnexpectedEof);
}
}
}
Ok(body)
}
fn read_chunked<R: BufRead>(r: &mut R) -> Result<Vec<u8>> {
let mut body = Vec::new();
loop {
let mut size_line = String::new();
let n = read_line_capped(r, &mut size_line, MAX_HEADER_BYTES)?;
if n == 0 {
return Err(Error::UnexpectedEof);
}
let size_str = size_line
.trim_end_matches(['\r', '\n'])
.split(';')
.next()
.unwrap_or("");
let s = size_str.trim();
if s.is_empty() || !s.bytes().all(|b| b.is_ascii_hexdigit()) {
return Err(Error::BadResponse(format!("bad chunk size: {size_str:?}")));
}
let size = usize::from_str_radix(s, 16)
.map_err(|_| Error::BadResponse(format!("bad chunk size: {size_str:?}")))?;
if body.len().saturating_add(size) > MAX_BODY_BYTES {
return Err(Error::BadResponse("body too large".into()));
}
if size == 0 {
let mut trailer_bytes: usize = 0;
loop {
let mut t = String::new();
let n = read_line_capped(r, &mut t, MAX_HEADER_BYTES)?;
if n == 0 || t.trim_end_matches(['\r', '\n']).is_empty() {
break;
}
trailer_bytes = trailer_bytes.saturating_add(n);
if trailer_bytes > MAX_HEADER_BYTES {
return Err(Error::BadResponse("trailer block too large".into()));
}
}
break;
}
let start = body.len();
body.resize(start + size, 0);
r.read_exact(&mut body[start..])?;
let mut crlf = [0u8; 2];
r.read_exact(&mut crlf)?;
if &crlf != b"\r\n" {
return Err(Error::BadResponse("missing CRLF after chunk".into()));
}
}
Ok(body)
}
trait Rw: Read + Write {
fn truncated(&self) -> bool {
false
}
}
impl Rw for TcpStream {}
impl Rw for crate::tls::TlsStream<TcpStream> {
fn truncated(&self) -> bool {
self.was_truncated()
}
}
pub(crate) trait TruncationAware {
fn response_truncated(&self) -> bool;
}
impl TruncationAware for TcpStream {
fn response_truncated(&self) -> bool {
false
}
}
impl TruncationAware for crate::tls::TlsStream<TcpStream> {
fn response_truncated(&self) -> bool {
self.was_truncated()
}
}
impl TruncationAware for Box<dyn NetStream> {
fn response_truncated(&self) -> bool {
false
}
}
impl TruncationAware for crate::tls::TlsStream<Box<dyn NetStream>> {
fn response_truncated(&self) -> bool {
self.was_truncated()
}
}
impl TruncationAware for Box<dyn Rw> {
fn response_truncated(&self) -> bool {
(**self).truncated()
}
}
impl<R: TruncationAware + ?Sized> TruncationAware for BufReader<R> {
fn response_truncated(&self) -> bool {
self.get_ref().response_truncated()
}
}
#[cfg(test)]
impl TruncationAware for std::io::Cursor<Vec<u8>> {
fn response_truncated(&self) -> bool {
false
}
}
fn stream_body<R: BufRead + TruncationAware, W: Write + ?Sized>(
r: &mut R,
sink: &mut W,
headers: &[(String, String)],
status: u16,
method: &str,
) -> Result<u64> {
if method.eq_ignore_ascii_case("HEAD")
|| (100..200).contains(&status)
|| status == 204
|| status == 304
{
return Ok(0);
}
let chunked = headers.iter().any(|(k, v)| {
k.eq_ignore_ascii_case("transfer-encoding") && v.eq_ignore_ascii_case("chunked")
});
if chunked {
return stream_chunked(r, sink);
}
match parse_content_length(headers)? {
Some(len) => {
if len > MAX_BODY_BYTES as u64 {
return Err(Error::BadResponse(format!("body too large: {len}")));
}
let n = io::copy(&mut r.by_ref().take(len), sink)?;
if n < len {
return Err(Error::UnexpectedEof);
}
Ok(n)
}
None => {
let n = io::copy(&mut r.by_ref().take(MAX_BODY_BYTES as u64), sink)?;
if r.response_truncated() {
return Err(Error::UnexpectedEof);
}
Ok(n)
}
}
}
fn stream_chunked<R: BufRead, W: Write + ?Sized>(r: &mut R, sink: &mut W) -> Result<u64> {
let mut total: u64 = 0;
loop {
let mut size_line = String::new();
let n = read_line_capped(r, &mut size_line, MAX_HEADER_BYTES)?;
if n == 0 {
return Err(Error::UnexpectedEof);
}
let size_str = size_line
.trim_end_matches(['\r', '\n'])
.split(';')
.next()
.unwrap_or("");
let s = size_str.trim();
if s.is_empty() || !s.bytes().all(|b| b.is_ascii_hexdigit()) {
return Err(Error::BadResponse(format!("bad chunk size: {size_str:?}")));
}
let size = usize::from_str_radix(s, 16)
.map_err(|_| Error::BadResponse(format!("bad chunk size: {size_str:?}")))?;
if total.saturating_add(size as u64) > MAX_BODY_BYTES as u64 {
return Err(Error::BadResponse("body too large".into()));
}
if size == 0 {
let mut trailer_bytes: usize = 0;
loop {
let mut t = String::new();
let n = read_line_capped(r, &mut t, MAX_HEADER_BYTES)?;
if n == 0 || t.trim_end_matches(['\r', '\n']).is_empty() {
break;
}
trailer_bytes = trailer_bytes.saturating_add(n);
if trailer_bytes > MAX_HEADER_BYTES {
return Err(Error::BadResponse("trailer block too large".into()));
}
}
break;
}
let copied = io::copy(&mut r.by_ref().take(size as u64), sink)?;
if copied < size as u64 {
return Err(Error::UnexpectedEof);
}
let mut crlf = [0u8; 2];
r.read_exact(&mut crlf)?;
if &crlf != b"\r\n" {
return Err(Error::BadResponse("missing CRLF after chunk".into()));
}
total += size as u64;
}
Ok(total)
}
#[cfg(test)]
mod tests {
    // Unit tests for request building, response parsing, body framing, TLS
    // option wiring, connection-pool keys and proxy handling.
    use super::*;

    #[test]
    fn parses_status_line_ok() {
        let (v, s, r) = parse_status_line("HTTP/1.1 200 OK").unwrap();
        assert_eq!(v, "HTTP/1.1");
        assert_eq!(s, 200);
        assert_eq!(r, "OK");
    }

    #[test]
    fn method_override_uppercases() {
        let req = Request::get("http://example.com/").unwrap().method("head");
        assert_eq!(effective_method(&req), "HEAD");
        let raw = Request::get("http://example.com/")
            .unwrap()
            .method("head")
            .keep_method_case(true);
        assert_eq!(effective_method(&raw), "head");
    }

    #[test]
    fn parses_status_line_no_reason() {
        let (_, s, r) = parse_status_line("HTTP/1.0 204").unwrap();
        assert_eq!(s, 204);
        assert_eq!(r, "");
    }

    #[test]
    fn rejects_non_http() {
        assert!(parse_status_line("RTSP/1.0 200 OK").is_err());
    }

    #[test]
    fn header_name_token_validation() {
        assert!(is_valid_header_name("X-Custom-Header"));
        assert!(is_valid_header_name("Content-Type"));
        assert!(!is_valid_header_name(""));
        assert!(!is_valid_header_name("Bad Name"));
        assert!(!is_valid_header_name("Bad:Name"));
        assert!(!is_valid_header_name("Bad\r\nName"));
    }

    #[test]
    fn validate_header_rejects_crlf_in_value() {
        assert!(validate_header("X", "ok").is_ok());
        assert!(validate_header("X", "evil\r\nInjected: 1").is_err());
        assert!(validate_header("X", "evil\rstuff").is_err());
        assert!(validate_header("X", "evil\nstuff").is_err());
        assert!(validate_header("X", "evil\0stuff").is_err());
    }

    #[test]
    fn validate_method_rejects_control_and_space() {
        assert!(validate_method("GET").is_ok());
        assert!(validate_method("PROPFIND").is_ok());
        assert!(validate_method("").is_err());
        assert!(validate_method("GET HTTP/1.1\r\nEvil:").is_err());
        assert!(validate_method("BAD\r\n").is_err());
        assert!(validate_method("BAD METHOD").is_err());
    }

    #[test]
    fn write_request_refuses_injected_header() {
        let mut req = Request::get("http://example.com/").unwrap();
        req.headers
            .push(("X-Evil".into(), "a\r\nInjected: 1".into()));
        let mut sink = Vec::new();
        let mut trace = Vec::new();
        let err = write_request(&mut sink, &req, false, &mut trace).unwrap_err();
        assert!(matches!(err, Error::BadResponse(_)));
        assert!(sink.is_empty(), "nothing should have been written");
    }

    #[test]
    fn write_request_refuses_injected_method() {
        let mut req = Request::get("http://example.com/").unwrap();
        req.method = "GET\r\nEvil: 1".into();
        let mut sink = Vec::new();
        let mut trace = Vec::new();
        assert!(write_request(&mut sink, &req, false, &mut trace).is_err());
        assert!(sink.is_empty());
    }

    #[test]
    fn http3_builder_methods_set_pref() {
        let r = Request::get("https://example.com/").unwrap().http3();
        assert_eq!(r.http_version_pref, HttpVersionPref::Http3);
        let r = Request::get("https://example.com/").unwrap().http3_only();
        assert_eq!(r.http_version_pref, HttpVersionPref::Http3Only);
        let r = Request::get("https://example.com/")
            .unwrap()
            .http_version(HttpVersionPref::Http3);
        assert_eq!(r.http_version_pref, HttpVersionPref::Http3);
    }

    #[test]
    fn h3_fallback_classification() {
        use std::io;
        assert!(h3_should_fall_back(&Error::Io(io::Error::new(
            io::ErrorKind::TimedOut,
            "x"
        ))));
        assert!(h3_should_fall_back(&Error::BadResponse(
            "http3: feed: Decode".into()
        )));
        assert!(h3_should_fall_back(&Error::BadResponse(
            "http3: connection closed mid-handshake".into()
        )));
        assert!(h3_should_fall_back(&Error::BadResponse(
            "http3: peer closed connection".into()
        )));
        assert!(!h3_should_fall_back(&Error::BadResponse(
            "qpack: dynamic index out of range".into()
        )));
        assert!(!h3_should_fall_back(&Error::H2NotNegotiated));
    }

    #[test]
    fn http3_only_over_plaintext_http_errors() {
        let req = Request::get("http://example.com/").unwrap().http3_only();
        let mut trace = Vec::new();
        let err = send_plain(req, &mut trace).unwrap_err();
        assert!(matches!(err, Error::UnsupportedScheme(_)));
    }

    #[test]
    fn tls_pool_eligible_only_for_default_posture() {
        let mut req = Request::get("https://example.com/").unwrap();
        assert!(tls_pool_eligible(&req));
        req.verify_tls = false;
        assert!(!tls_pool_eligible(&req));
        req.verify_tls = true;
        req.ca_bundle = Some("/tmp/ca.pem".into());
        assert!(!tls_pool_eligible(&req));
        req.ca_bundle = None;
        req.ca_path = Some("/tmp/cadir".into());
        assert!(!tls_pool_eligible(&req));
        req.ca_path = None;
        req.client_cert = Some("/tmp/client.pem".into());
        assert!(!tls_pool_eligible(&req));
        req.client_cert = None;
        req.pinned_pubkey = Some("sha256//AAAA".into());
        assert!(!tls_pool_eligible(&req));
        req.pinned_pubkey = None;
        req.crl_file = Some("/tmp/crl.pem".into());
        assert!(!tls_pool_eligible(&req));
    }

    #[test]
    fn pool_key_no_overrides_pools_together() {
        let a = Request::get("http://example.com/a").unwrap();
        let b = Request::get("http://example.com/b").unwrap();
        let ka = pool_key_for(&a);
        let kb = pool_key_for(&b);
        assert_eq!(ka, kb);
        assert_eq!(ka.effective_target, None);
    }

    #[test]
    fn pool_key_connect_to_distinguishes_dial_target() {
        let plain = Request::get("http://example.com/").unwrap();
        let remapped = Request::get("http://example.com/").unwrap().connect_to(
            "example.com",
            80,
            "10.0.0.1",
            8080,
        );
        let kp = pool_key_for(&plain);
        let kr = pool_key_for(&remapped);
        assert_ne!(kp, kr);
        assert_eq!(kr.host, "example.com");
        assert_eq!(kr.port, 80);
        assert_eq!(kr.effective_target, Some(("10.0.0.1".to_string(), 8080)));
        assert_eq!(kp.effective_target, None);
    }

    #[test]
    fn pool_key_resolve_distinguishes_pinned_ip() {
        let plain = Request::get("http://example.com/").unwrap();
        let to_a = Request::get("http://example.com/").unwrap().resolve_addr(
            "example.com",
            80,
            "10.0.0.1".parse().unwrap(),
        );
        let to_b = Request::get("http://example.com/").unwrap().resolve_addr(
            "example.com",
            80,
            "10.0.0.2".parse().unwrap(),
        );
        let kp = pool_key_for(&plain);
        let ka = pool_key_for(&to_a);
        let kb = pool_key_for(&to_b);
        assert_ne!(ka, kb);
        assert_ne!(ka, kp);
        assert_ne!(kb, kp);
        assert_eq!(ka.effective_target, Some(("10.0.0.1".to_string(), 80)));
        assert_eq!(kb.effective_target, Some(("10.0.0.2".to_string(), 80)));
    }

    #[test]
    fn pool_key_non_matching_override_is_transparent() {
        let req = Request::get("http://example.com/")
            .unwrap()
            .connect_to("other.example", 80, "10.0.0.1", 9000)
            .resolve_addr("other.example", 80, "10.0.0.2".parse().unwrap());
        assert_eq!(pool_key_for(&req).effective_target, None);
        let bare = Request::get("http://example.com/").unwrap();
        assert_eq!(pool_key_for(&req), pool_key_for(&bare));
    }

    #[test]
    fn effective_dial_target_resolve_follows_connect_to() {
        let connect_to = vec![(
            "example.com".to_string(),
            80,
            "backend".to_string(),
            8080u16,
        )];
        let resolve = vec![("backend".to_string(), 8080u16, "10.0.0.5".parse().unwrap())];
        let t = effective_dial_target(&connect_to, &resolve, "example.com", 80);
        assert_eq!(t, Some(("10.0.0.5".to_string(), 8080)));
    }

    #[test]
    fn tls_opts_from_wires_client_cert_and_pins() {
        use std::io::Write;
        let (leaf_der, key_pem) = crate::tls::client_auth::tests_support_ed25519_leaf();
        let cert_pem = purecrypto::x509::Certificate::from_der(leaf_der.clone())
            .unwrap()
            .to_pem();
        let dir = std::env::temp_dir();
        let cert_path = dir.join(format!("rsurl_test_cert_{}.pem", std::process::id()));
        let key_path = dir.join(format!("rsurl_test_key_{}.pem", std::process::id()));
        std::fs::File::create(&cert_path)
            .unwrap()
            .write_all(cert_pem.as_bytes())
            .unwrap();
        std::fs::File::create(&key_path)
            .unwrap()
            .write_all(key_pem.as_bytes())
            .unwrap();
        let pin = crate::tls::client_auth::leaf_spki_sha256(&leaf_der).unwrap();
        let b64 = pin_to_sha256_spec(&pin);
        let req = Request::get("https://example.com/")
            .unwrap()
            .client_cert(cert_path.to_str().unwrap())
            .client_key(key_path.to_str().unwrap())
            .pinned_pubkey(&b64);
        let opts = tls_opts_from(&req, &[]).expect("build TlsOpts");
        assert!(opts.client_cert.is_some());
        assert!(opts.client_key.is_some());
        assert_eq!(opts.pinned_spki_sha256, vec![pin]);
        let _ = std::fs::remove_file(&cert_path);
        let _ = std::fs::remove_file(&key_path);
    }

    #[test]
    fn partition_key_distinguishes_pool_keys() {
        let a = Request::get("http://example.com/")
            .unwrap()
            .partition("siteA");
        let b = Request::get("http://example.com/")
            .unwrap()
            .partition("siteB");
        let bare = Request::get("http://example.com/").unwrap();
        assert_ne!(pool_key_for(&a), pool_key_for(&b));
        assert_ne!(pool_key_for(&a), pool_key_for(&bare));
        let a2 = Request::get("http://example.com/")
            .unwrap()
            .partition("siteA");
        assert_eq!(pool_key_for(&a), pool_key_for(&a2));
    }

    #[test]
    fn tls_opts_from_carries_verify_callback() {
        use crate::tls::{CertVerdict, VerifyCallback};
        let req = Request::get("https://example.com/")
            .unwrap()
            .tls_verify_callback(VerifyCallback::new(|_| CertVerdict::Accept));
        let opts = tls_opts_from(&req, &[]).expect("build TlsOpts");
        assert!(
            opts.verify_callback.is_some(),
            "verify callback should reach TlsOpts"
        );
    }

    /// Formats a SHA-256 SPKI hash as a `sha256//<base64>` pin spec.
    fn pin_to_sha256_spec(hash: &[u8; 32]) -> String {
        format!("sha256//{}", crate::websocket::base64_encode(hash))
    }

    #[test]
    fn tls_opts_from_reads_crl_file() {
        use std::io::Write;
        let dir = std::env::temp_dir();
        let path = dir.join(format!("rsurl_test_crl_{}.pem", std::process::id()));
        let body = b"-----BEGIN X509 CRL-----\nMIIB\n-----END X509 CRL-----\n";
        std::fs::File::create(&path)
            .unwrap()
            .write_all(body)
            .unwrap();
        let req = Request::get("https://example.com/")
            .unwrap()
            .crl_file(path.to_str().unwrap());
        assert!(!tls_pool_eligible(&req));
        let opts = tls_opts_from(&req, &[]).expect("build TlsOpts");
        assert_eq!(opts.crl_pem.as_deref(), Some(body.as_slice()));
        let _ = std::fs::remove_file(&path);
    }

    /// Builds a minimal HTTP/1.1 `Response` with an optional Content-Type.
    fn resp_with(content_type: Option<&str>, status: u16, body: &[u8]) -> Response {
        let mut headers = Vec::new();
        if let Some(ct) = content_type {
            headers.push(("Content-Type".to_string(), ct.to_string()));
        }
        Response {
            status,
            reason: if status == 404 {
                "Not Found".into()
            } else {
                "OK".into()
            },
            version: "HTTP/1.1".into(),
            headers,
            body: body.to_vec(),
            timing: Timing::default(),
            final_url: String::new(),
            tls: None,
        }
    }

    #[test]
    fn response_text_decodes_charsets() {
        assert_eq!(
            resp_with(None, 200, "héllo".as_bytes()).text().unwrap(),
            "héllo"
        );
        assert_eq!(
            resp_with(Some("text/plain; Charset=\"UTF-8\""), 200, b"hi")
                .text()
                .unwrap(),
            "hi"
        );
        let lossy = resp_with(None, 200, b"a\xffb").text().unwrap();
        assert_eq!(lossy, "a\u{fffd}b");
        assert_eq!(
            resp_with(Some("text/plain; charset=iso-8859-1"), 200, &[b'c', 0xE9])
                .text()
                .unwrap(),
            "cé"
        );
        let err = resp_with(Some("text/plain; charset=shift_jis"), 200, b"x").text();
        assert!(matches!(err, Err(Error::Decode(_))), "got {err:?}");
    }

    #[test]
    fn response_error_for_status() {
        assert!(resp_with(None, 200, b"ok").error_for_status().is_ok());
        assert!(resp_with(None, 302, b"").error_for_status().is_ok());
        match resp_with(None, 404, b"nope").error_for_status() {
            Err(Error::Status { code, reason }) => {
                assert_eq!(code, 404);
                assert_eq!(reason, "Not Found");
            }
            other => panic!("expected Status error, got {other:?}"),
        }
        assert!(matches!(
            resp_with(None, 500, b"").error_for_status(),
            Err(Error::Status { code: 500, .. })
        ));
    }

    #[cfg(feature = "json")]
    #[test]
    fn response_json_deserializes() {
        let r = resp_with(Some("application/json"), 200, br#"{"a":1,"b":["x","y"]}"#);
        let v: serde_json::Value = r.json().expect("parse json");
        assert_eq!(v["a"], 1);
        assert_eq!(v["b"][1], "y");
        let bad = resp_with(Some("application/json"), 200, b"{not json");
        assert!(matches!(
            bad.json::<serde_json::Value>(),
            Err(Error::Decode(_))
        ));
    }

    #[test]
    fn content_length_single_ok() {
        let h = vec![("Content-Length".to_string(), "42".to_string())];
        assert_eq!(parse_content_length(&h).unwrap(), Some(42));
    }

    #[test]
    fn content_length_absent_is_none() {
        let h = vec![("X".to_string(), "y".to_string())];
        assert_eq!(parse_content_length(&h).unwrap(), None);
    }

    #[test]
    fn content_length_duplicate_agreeing_ok() {
        let h = vec![
            ("Content-Length".to_string(), "5".to_string()),
            ("content-length".to_string(), "5".to_string()),
        ];
        assert_eq!(parse_content_length(&h).unwrap(), Some(5));
    }

    #[test]
    fn content_length_conflicting_rejected() {
        let h = vec![
            ("Content-Length".to_string(), "5".to_string()),
            ("Content-Length".to_string(), "6".to_string()),
        ];
        assert!(parse_content_length(&h).is_err());
    }

    #[test]
    fn content_length_comma_list_conflicting_rejected() {
        let h = vec![("Content-Length".to_string(), "5, 6".to_string())];
        assert!(parse_content_length(&h).is_err());
    }

    #[test]
    fn content_length_comma_list_agreeing_ok() {
        let h = vec![("Content-Length".to_string(), "5, 5".to_string())];
        assert_eq!(parse_content_length(&h).unwrap(), Some(5));
    }

    #[test]
    fn content_length_unparseable_rejected() {
        let h = vec![("Content-Length".to_string(), "not-a-number".to_string())];
        assert!(parse_content_length(&h).is_err());
    }

    #[test]
    fn content_length_signed_rejected() {
        let h = vec![("Content-Length".to_string(), "+5".to_string())];
        assert!(parse_content_length(&h).is_err());
    }

    #[test]
    fn content_length_plain_digits_ok() {
        let h = vec![("Content-Length".to_string(), "5".to_string())];
        assert_eq!(parse_content_length(&h).unwrap(), Some(5));
    }

    #[test]
    fn read_body_rejects_te_and_cl_together() {
        use std::io::Cursor;
        let headers = vec![
            ("Transfer-Encoding".to_string(), "chunked".to_string()),
            ("Content-Length".to_string(), "5".to_string()),
        ];
        let mut r = BufReader::new(Cursor::new(b"0\r\n\r\n".to_vec()));
        let err = read_body(&mut r, &headers, "HTTP/1.1", 200, "GET").unwrap_err();
        assert!(matches!(err, Error::BadResponse(_)));
    }

    #[test]
    fn read_body_rejects_conflicting_content_length() {
        use std::io::Cursor;
        let headers = vec![
            ("Content-Length".to_string(), "3".to_string()),
            ("Content-Length".to_string(), "4".to_string()),
        ];
        let mut r = BufReader::new(Cursor::new(b"abcd".to_vec()));
        assert!(read_body(&mut r, &headers, "HTTP/1.1", 200, "GET").is_err());
    }

    /// In-memory `BufRead` whose truncation flag is set by the test.
    struct TruncReader {
        inner: std::io::Cursor<Vec<u8>>,
        truncated: bool,
    }

    impl Read for TruncReader {
        fn read(&mut self, b: &mut [u8]) -> io::Result<usize> {
            self.inner.read(b)
        }
    }

    impl BufRead for TruncReader {
        fn fill_buf(&mut self) -> io::Result<&[u8]> {
            self.inner.fill_buf()
        }
        fn consume(&mut self, n: usize) {
            self.inner.consume(n)
        }
    }

    impl TruncationAware for TruncReader {
        fn response_truncated(&self) -> bool {
            self.truncated
        }
    }

    #[test]
    fn read_body_rejects_truncated_eof_delimited() {
        let headers: Vec<(String, String)> = vec![];
        let mut r = TruncReader {
            inner: std::io::Cursor::new(b"partial body".to_vec()),
            truncated: true,
        };
        let err = read_body(&mut r, &headers, "HTTP/1.1", 200, "GET").unwrap_err();
        assert!(matches!(err, Error::UnexpectedEof));
    }

    #[test]
    fn read_body_accepts_clean_eof_delimited() {
        let headers: Vec<(String, String)> = vec![];
        let mut r = TruncReader {
            inner: std::io::Cursor::new(b"complete body".to_vec()),
            truncated: false,
        };
        let body = read_body(&mut r, &headers, "HTTP/1.1", 200, "GET").unwrap();
        assert_eq!(body, b"complete body");
    }

    #[test]
    fn read_body_ignores_truncation_for_content_length() {
        let headers = vec![("Content-Length".to_string(), "5".to_string())];
        let mut r = TruncReader {
            inner: std::io::Cursor::new(b"hello".to_vec()),
            truncated: true,
        };
        let body = read_body(&mut r, &headers, "HTTP/1.1", 200, "GET").unwrap();
        assert_eq!(body, b"hello");
    }

    #[test]
    fn stream_body_rejects_truncated_eof_delimited() {
        let headers: Vec<(String, String)> = vec![];
        let mut r = TruncReader {
            inner: std::io::Cursor::new(b"partial".to_vec()),
            truncated: true,
        };
        let mut sink: Vec<u8> = Vec::new();
        let err = stream_body(&mut r, &mut sink, &headers, 200, "GET").unwrap_err();
        assert!(matches!(err, Error::UnexpectedEof));
    }

    #[test]
    fn read_line_capped_errors_on_overlong_line() {
        use std::io::Cursor;
        let huge = vec![b'a'; MAX_HEADER_BYTES + 4096];
        let mut r = BufReader::new(Cursor::new(huge));
        let mut buf = String::new();
        let err = read_line_capped(&mut r, &mut buf, MAX_HEADER_BYTES).unwrap_err();
        assert!(matches!(err, Error::BadResponse(_)), "got {err:?}");
    }

    #[test]
    fn read_line_capped_reads_normal_line_and_eof() {
        use std::io::Cursor;
        let mut r = BufReader::new(Cursor::new(b"hello\r\nworld".to_vec()));
        let mut a = String::new();
        let n = read_line_capped(&mut r, &mut a, MAX_HEADER_BYTES).unwrap();
        assert_eq!(n, 7);
        assert_eq!(a, "hello\r\n");
        let mut b = String::new();
        let n = read_line_capped(&mut r, &mut b, MAX_HEADER_BYTES).unwrap();
        assert_eq!(n, 5);
        assert_eq!(b, "world");
        let mut c = String::new();
        let n = read_line_capped(&mut r, &mut c, MAX_HEADER_BYTES).unwrap();
        assert_eq!(n, 0);
    }

    #[test]
    fn read_response_rejects_overlong_status_line() {
        use std::io::Cursor;
        let huge = vec![b'X'; MAX_HEADER_BYTES + 4096];
        let mut r = BufReader::new(Cursor::new(huge));
        let mut trace = Vec::new();
        let err = read_response(&mut r, "GET", true, &mut trace).unwrap_err();
        assert!(matches!(err, Error::BadResponse(_)), "got {err:?}");
    }

    #[test]
    fn read_response_rejects_overlong_header_line() {
        use std::io::Cursor;
        let mut bytes = b"HTTP/1.1 200 OK\r\n".to_vec();
        bytes.extend(std::iter::repeat_n(b'a', MAX_HEADER_BYTES + 4096));
        let mut r = BufReader::new(Cursor::new(bytes));
        let mut trace = Vec::new();
        let err = read_response(&mut r, "GET", true, &mut trace).unwrap_err();
        assert!(matches!(err, Error::BadResponse(_)), "got {err:?}");
    }

    #[test]
    fn read_chunked_rejects_overlong_size_line() {
        use std::io::Cursor;
        let huge = vec![b'f'; MAX_HEADER_BYTES + 4096];
        let mut r = BufReader::new(Cursor::new(huge));
        let err = read_chunked(&mut r).unwrap_err();
        assert!(matches!(err, Error::BadResponse(_)), "got {err:?}");
    }

    #[test]
    fn read_chunked_rejects_signed_chunk_size() {
        use std::io::Cursor;
        let mut r = BufReader::new(Cursor::new(b"+a\r\n".to_vec()));
        let err = read_chunked(&mut r).unwrap_err();
        assert!(matches!(err, Error::BadResponse(_)), "got {err:?}");
    }

    #[test]
    fn read_chunked_rejects_internal_junk_chunk_size() {
        use std::io::Cursor;
        let mut r = BufReader::new(Cursor::new(b"a b\r\n".to_vec()));
        let err = read_chunked(&mut r).unwrap_err();
        assert!(matches!(err, Error::BadResponse(_)), "got {err:?}");
    }

    #[test]
    fn read_chunked_accepts_plain_hex_chunk_size() {
        use std::io::Cursor;
        let mut payload = Vec::new();
        payload.extend_from_slice(b"a\r\n");
        payload.extend_from_slice(b"0123456789\r\n");
        payload.extend_from_slice(b"1f\r\n");
        payload.extend_from_slice(&[b'x'; 31]);
        payload.extend_from_slice(b"\r\n");
        payload.extend_from_slice(b"0\r\n\r\n");
        let mut r = BufReader::new(Cursor::new(payload));
        let body = read_chunked(&mut r).unwrap();
        assert_eq!(body.len(), 10 + 31);
        assert_eq!(&body[..10], b"0123456789");
        assert!(body[10..].iter().all(|&b| b == b'x'));
    }

    #[test]
    fn read_chunked_rejects_oversized_trailer_block() {
        use std::io::Cursor;
        let mut payload = Vec::new();
        payload.extend_from_slice(b"0\r\n");
        let line = b"X-Trailer: aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa\r\n";
        let reps = (MAX_HEADER_BYTES / line.len()) + 16;
        for _ in 0..reps {
            payload.extend_from_slice(line);
        }
        let mut r = BufReader::new(Cursor::new(payload));
        let err = read_chunked(&mut r).unwrap_err();
        assert!(matches!(err, Error::BadResponse(_)), "got {err:?}");
    }

    #[test]
    fn stream_chunked_rejects_oversized_trailer_block() {
        use std::io::Cursor;
        let mut payload = Vec::new();
        payload.extend_from_slice(b"0\r\n");
        let line = b"X-Trailer: aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa\r\n";
        let reps = (MAX_HEADER_BYTES / line.len()) + 16;
        for _ in 0..reps {
            payload.extend_from_slice(line);
        }
        let mut r = BufReader::new(Cursor::new(payload));
        let mut sink = Vec::new();
        let err = stream_chunked(&mut r, &mut sink).unwrap_err();
        assert!(matches!(err, Error::BadResponse(_)), "got {err:?}");
    }

    #[test]
    fn digest_scheme_detection_is_char_boundary_safe() {
        fn is_digest(chal: &str) -> bool {
            let scheme = chal.trim_start();
            scheme
                .as_bytes()
                .get(..6)
                .is_some_and(|b| b.eq_ignore_ascii_case(b"digest"))
        }
        assert!(!is_digest("Digé realm=\"x\""));
        assert!(!is_digest("é"));
        assert!(!is_digest("\u{1f600}abcdef"));
        assert!(is_digest("Digest realm=\"x\""));
        assert!(is_digest(" digest realm=\"x\""));
        assert!(!is_digest("Basic realm=\"x\""));
        assert!(!is_digest("Dig"));
    }

    #[test]
    fn proxy_parse_basic() {
        let p = ProxyConfig::parse("http://proxy.example:3128").unwrap();
        assert_eq!(p.host, "proxy.example");
        assert_eq!(p.port, 3128);
        assert!(p.auth.is_none());
    }

    #[test]
    fn proxy_parse_with_creds() {
        let p = ProxyConfig::parse("http://alice:hunter2@proxy:8080").unwrap();
        assert_eq!(p.host, "proxy");
        assert_eq!(p.port, 8080);
        assert_eq!(p.auth.as_ref().unwrap().0, "alice");
        assert_eq!(p.auth.as_ref().unwrap().1, "hunter2");
    }

    #[test]
    fn proxy_parse_bare_hostport_is_http() {
        let p = ProxyConfig::parse("proxy.local:8080").unwrap();
        assert_eq!(p.host, "proxy.local");
        assert_eq!(p.port, 8080);
    }

    #[test]
    fn proxy_parse_rejects_https() {
        let err = ProxyConfig::parse("https://proxy:443").unwrap_err();
        // Previously a bare `matches!(..);` whose result was discarded, so the
        // error variant was never actually checked.
        assert!(matches!(err, Error::UnsupportedScheme(_)), "got {err:?}");
    }

    #[test]
    fn proxy_bypass_matches_suffix() {
        let mut req = Request::get("http://api.example.com/x").unwrap();
        req.proxy = Some(ProxyConfig::parse("http://proxy:8080").unwrap());
        req.no_proxy = vec!["example.com".into()];
        assert!(proxy_bypassed(&req));
        req.url = Url::parse("http://other.org/x").unwrap();
        assert!(!proxy_bypassed(&req));
    }

    #[test]
    fn proxy_bypass_wildcard() {
        let mut req = Request::get("http://anywhere/").unwrap();
        req.proxy = Some(ProxyConfig::parse("http://p:1").unwrap());
        req.no_proxy = vec!["*".into()];
        assert!(proxy_bypassed(&req));
    }

    #[test]
    fn connect_tunnel_happy_path() {
        use std::io::{self, Cursor};
        /// Scripted stream: serves `reply`, then `trailing` (bytes for the next
        /// protocol layer), and records everything written.
        struct Mock {
            written: Vec<u8>,
            reply: Cursor<Vec<u8>>,
            trailing: Vec<u8>,
        }
        impl Read for Mock {
            fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
                let n = self.reply.read(buf)?;
                if n == 0 && !self.trailing.is_empty() {
                    let take = buf.len().min(self.trailing.len());
                    buf[..take].copy_from_slice(&self.trailing[..take]);
                    self.trailing.drain(..take);
                    return Ok(take);
                }
                Ok(n)
            }
        }
        impl Write for Mock {
            fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
                self.written.extend_from_slice(buf);
                Ok(buf.len())
            }
            fn flush(&mut self) -> io::Result<()> {
                Ok(())
            }
        }
        let mut mock = Mock {
            written: Vec::new(),
            reply: Cursor::new(b"HTTP/1.1 200 Connection established\r\n\r\n".to_vec()),
            trailing: vec![0x16],
        };
        let target = Url::parse("https://origin.example:443/").unwrap();
        let proxy = ProxyConfig {
            host: "proxy".into(),
            port: 3128,
            auth: Some(("u".into(), "p".into())),
        };
        connect_tunnel(&mut mock, &target, &proxy, &mut io::sink()).unwrap();
        let written = String::from_utf8(mock.written.clone()).unwrap();
        assert!(
            written.starts_with("CONNECT origin.example:443 HTTP/1.1\r\n"),
            "request line missing: {written:?}",
        );
        assert!(
            written.contains("Host: origin.example:443\r\n"),
            "Host header missing: {written:?}",
        );
        assert!(
            written.contains("Proxy-Authorization: Basic dTpw\r\n"),
            "auth header missing or wrong: {written:?}",
        );
        let mut byte = [0u8; 1];
        assert_eq!(mock.read(&mut byte).unwrap(), 1);
        assert_eq!(byte[0], 0x16, "next-layer byte was consumed by the tunnel");
    }

    #[test]
    fn connect_tunnel_reports_407() {
        use std::io::{self, Cursor};
        let payload =
            b"HTTP/1.1 407 Proxy Authentication Required\r\nProxy-Authenticate: Basic\r\n\r\n";
        /// Reads the scripted proxy reply and captures whatever the tunnel writes.
        struct RW<'a> {
            inner: Cursor<&'a [u8]>,
            sink: &'a mut Vec<u8>,
        }
        impl<'a> Read for RW<'a> {
            fn read(&mut self, b: &mut [u8]) -> io::Result<usize> {
                self.inner.read(b)
            }
        }
        impl<'a> Write for RW<'a> {
            fn write(&mut self, b: &[u8]) -> io::Result<usize> {
                self.sink.extend_from_slice(b);
                Ok(b.len())
            }
            fn flush(&mut self) -> io::Result<()> {
                Ok(())
            }
        }
        let mut sink = Vec::new();
        let mut rw = RW {
            inner: Cursor::new(payload),
            sink: &mut sink,
        };
        let target = Url::parse("https://origin/").unwrap();
        let proxy = ProxyConfig {
            host: "p".into(),
            port: 1,
            auth: None,
        };
        let err = connect_tunnel(&mut rw, &target, &proxy, &mut io::sink()).unwrap_err();
        match err {
            Error::BadResponse(msg) => assert!(msg.contains("407"), "got {msg:?}"),
            other => panic!("unexpected: {other:?}"),
        }
    }
}