use std::io::{self, BufRead, BufReader, Read, Write};
use std::net::TcpStream;
use std::time::Duration;
use crate::error::{Error, Result};
use crate::url::Url;
/// `User-Agent` sent on requests and CONNECT tunnels unless the caller supplies one:
/// `rsurl/` followed by this crate's version at build time.
const DEFAULT_USER_AGENT: &str = concat!("rsurl/", env!("CARGO_PKG_VERSION"));
/// Cap on the size of a response (or CONNECT reply) header section: 64 KiB.
const MAX_HEADER_BYTES: usize = 1 << 16;
/// Cap on a response body buffered in memory: 256 MiB.
pub(crate) const MAX_BODY_BYTES: usize = 1 << 28;
/// Which HTTP version(s) a request may use. Consulted by `send_plain` and `send_https`;
/// plain `http://` requests always go out as HTTP/1.1 (or fail for `Http3Only`).
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum HttpVersionPref {
    /// Over https: try HTTP/2 via ALPN, fall back to HTTP/1.1 if h2 is not selected.
    #[default]
    Auto,
    /// Always HTTP/1.1 (`--http1.1`).
    Http11Only,
    /// Over https: HTTP/2 with no HTTP/1.1 fallback (`--http2`).
    Http2Only,
    /// Over https: try HTTP/3 first, then HTTP/2 / HTTP/1.1 on recoverable errors.
    /// Over http: logged and served as HTTP/1.1.
    Http3,
    /// HTTP/3 only; an error for http:// URLs.
    Http3Only,
}
/// An HTTP request plus the transport options used to send it. Built with
/// [`Request::new`] and the builder methods, then consumed by one of the `send*` methods.
#[derive(Debug, Clone)]
pub struct Request {
    /// Request method; upper-cased by `Request::new`.
    pub(crate) method: String,
    /// Target URL (updated in place while following redirects).
    pub(crate) url: Url,
    /// Extra header fields, in insertion order.
    pub(crate) headers: Vec<(String, String)>,
    /// Request body bytes (empty = no body).
    pub(crate) body: Vec<u8>,
    /// TCP connect timeout; `None` connects without a timeout.
    pub(crate) connect_timeout: Option<Duration>,
    /// Applied by `tcp_connect` as both the socket read and write timeout.
    pub(crate) read_timeout: Option<Duration>,
    /// Protocol selection, see [`HttpVersionPref`].
    pub(crate) http_version_pref: HttpVersionPref,
    /// Whether 301/302/303/307/308 responses are followed.
    pub(crate) follow_redirects: bool,
    /// Maximum number of redirects followed before erroring.
    pub(crate) max_redirs: u32,
    /// Explicit Basic credentials; take precedence over URL userinfo and are dropped
    /// on a redirect to a different scheme/host/port.
    pub(crate) basic_auth: Option<(String, String)>,
    /// `false` disables certificate verification (`-k`) and excludes the request from
    /// TLS connection pooling.
    pub(crate) verify_tls: bool,
    /// Path of a trust-roots file loaded for TLS; also excludes the request from TLS pooling.
    pub(crate) ca_bundle: Option<String>,
    /// Overall deadline, checked before each attempt in a redirect chain.
    pub(crate) max_time: Option<Duration>,
    /// HTTP proxy to route through, unless bypassed by `no_proxy`.
    pub(crate) proxy: Option<ProxyConfig>,
    /// Host suffixes that bypass the proxy (leading `.` ignored, `*` matches every host).
    pub(crate) no_proxy: Vec<String>,
}
/// An `http://` proxy endpoint, as produced by [`ProxyConfig::parse`].
#[derive(Debug, Clone)]
pub struct ProxyConfig {
    /// Proxy host name or address.
    pub host: String,
    /// Proxy TCP port.
    pub port: u16,
    /// Optional `(user, password)` sent as `Proxy-Authorization: Basic …`.
    pub auth: Option<(String, String)>,
}
impl ProxyConfig {
    /// Parses a proxy spec such as `proxy.local:3128` or `http://user:pw@proxy:8080`.
    ///
    /// A spec without `://` is treated as `http://…`. Userinfo becomes `auth`; a
    /// missing password becomes the empty string.
    ///
    /// # Errors
    /// URL parse errors, or [`Error::UnsupportedScheme`] for any scheme other than `http`.
    pub fn parse(s: &str) -> Result<Self> {
        let parsed = if s.contains("://") {
            Url::parse(s)?
        } else {
            Url::parse(&format!("http://{s}"))?
        };
        if parsed.scheme != "http" {
            return Err(Error::UnsupportedScheme(format!(
                "proxy scheme {:?} not supported (only http:// at this milestone)",
                parsed.scheme
            )));
        }
        let auth = parsed.userinfo.as_deref().map(|info| {
            let (user, pass) = info.split_once(':').unwrap_or((info, ""));
            (user.to_string(), pass.to_string())
        });
        Ok(ProxyConfig {
            host: parsed.host.clone(),
            port: parsed.port,
            auth,
        })
    }
}
impl Request {
    /// Creates a request for `method` (upper-cased) against `url`.
    ///
    /// Defaults: 30 s connect timeout, 60 s read timeout, [`HttpVersionPref::Auto`],
    /// redirects not followed (limit 50 once enabled), certificate verification on,
    /// no CA bundle, no overall deadline, no proxy.
    ///
    /// # Errors
    /// Propagates the error from [`Url::parse`] when `url` cannot be parsed.
    pub fn new(method: &str, url: &str) -> Result<Self> {
        Ok(Request {
            method: method.to_ascii_uppercase(),
            url: Url::parse(url)?,
            headers: Vec::new(),
            body: Vec::new(),
            connect_timeout: Some(Duration::from_secs(30)),
            read_timeout: Some(Duration::from_secs(60)),
            http_version_pref: HttpVersionPref::Auto,
            follow_redirects: false,
            max_redirs: 50,
            basic_auth: None,
            verify_tls: true,
            ca_bundle: None,
            max_time: None,
            proxy: None,
            no_proxy: Vec::new(),
        })
    }

    /// Shorthand for `Request::new("GET", url)`.
    pub fn get(url: &str) -> Result<Self> {
        Self::new("GET", url)
    }

    /// Appends a header field. Repeated names are all sent; the HTTP/1.1 writer
    /// rejects names/values that are not valid on the wire.
    pub fn header(mut self, name: &str, value: &str) -> Self {
        self.headers.push((name.to_string(), value.to_string()));
        self
    }

    /// Replaces the request body.
    pub fn body<B: Into<Vec<u8>>>(mut self, body: B) -> Self {
        self.body = body.into();
        self
    }

    /// The request's target URL.
    pub fn url(&self) -> &Url {
        &self.url
    }

    /// Sets the HTTP version preference.
    pub fn http_version(mut self, pref: HttpVersionPref) -> Self {
        self.http_version_pref = pref;
        self
    }

    /// Requires HTTP/2 for https:// URLs.
    pub fn http2_only(mut self) -> Self {
        self.http_version_pref = HttpVersionPref::Http2Only;
        self
    }

    /// Forces HTTP/1.1.
    pub fn http11_only(mut self) -> Self {
        self.http_version_pref = HttpVersionPref::Http11Only;
        self
    }

    /// Tries HTTP/3 first, falling back on recoverable errors.
    pub fn http3(mut self) -> Self {
        self.http_version_pref = HttpVersionPref::Http3;
        self
    }

    /// Requires HTTP/3.
    pub fn http3_only(mut self) -> Self {
        self.http_version_pref = HttpVersionPref::Http3Only;
        self
    }

    /// Enables or disables following redirects.
    pub fn follow_redirects(mut self, on: bool) -> Self {
        self.follow_redirects = on;
        self
    }

    /// Sets how many redirects may be followed before the request fails.
    pub fn max_redirs(mut self, n: u32) -> Self {
        self.max_redirs = n;
        self
    }

    /// Sets Basic credentials; they override any userinfo embedded in the URL.
    pub fn basic_auth(mut self, user: &str, pass: &str) -> Self {
        self.basic_auth = Some((user.to_string(), pass.to_string()));
        self
    }

    /// Enables or disables TLS certificate verification.
    pub fn verify_tls(mut self, on: bool) -> Self {
        self.verify_tls = on;
        self
    }

    /// Uses the trust roots in the file at `path` instead of the defaults.
    pub fn ca_bundle(mut self, path: &str) -> Self {
        self.ca_bundle = Some(path.to_string());
        self
    }

    /// Sets an overall deadline, checked before each attempt of a redirect chain.
    pub fn max_time(mut self, d: Duration) -> Self {
        self.max_time = Some(d);
        self
    }

    /// Sets the TCP connect timeout.
    pub fn connect_timeout(mut self, d: Duration) -> Self {
        self.connect_timeout = Some(d);
        self
    }

    /// Routes the request through the proxy described by `spec`.
    ///
    /// # Errors
    /// Whatever [`ProxyConfig::parse`] reports for `spec`.
    pub fn proxy(mut self, spec: &str) -> Result<Self> {
        self.proxy = Some(ProxyConfig::parse(spec)?);
        Ok(self)
    }

    /// Sets credentials for the already-configured proxy.
    ///
    /// # Errors
    /// [`Error::BadResponse`] when no proxy has been set yet.
    pub fn proxy_user(mut self, user: &str, pass: &str) -> Result<Self> {
        match self.proxy.as_mut() {
            Some(p) => {
                p.auth = Some((user.to_string(), pass.to_string()));
                Ok(self)
            }
            None => Err(Error::BadResponse(
                "proxy_user called without a proxy set".into(),
            )),
        }
    }

    /// Replaces the list of hosts that bypass the proxy.
    pub fn no_proxy<I, S>(mut self, entries: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.no_proxy = entries.into_iter().map(Into::into).collect();
        self
    }

    /// Sends the request, discarding the trace output.
    pub fn send(self) -> Result<Response> {
        self.send_to(&mut io::sink(), None)
    }

    /// Sends the request, writing a curl-style `-v` trace to `trace`.
    pub fn send_traced(self, trace: &mut dyn Write) -> Result<Response> {
        self.send_to(trace, None)
    }

    /// Sends the request, attaching and updating cookies from `jar`.
    pub fn send_with_jar(self, jar: &mut crate::cookie::CookieJar) -> Result<Response> {
        self.send_to(&mut io::sink(), Some(jar))
    }

    /// Sends the request with both a cookie jar and a trace sink.
    pub fn send_traced_with_jar(
        self,
        jar: &mut crate::cookie::CookieJar,
        trace: &mut dyn Write,
    ) -> Result<Response> {
        self.send_to(trace, Some(jar))
    }

    /// Performs exactly one exchange (no redirect handling), dispatching on URL scheme.
    fn send_once(self, trace: &mut dyn Write) -> Result<Response> {
        if !self.verify_tls && self.url.scheme == "https" {
            let _ = writeln!(trace, "* WARNING: certificate verification disabled (-k)");
        }
        match self.url.scheme.as_str() {
            "http" => send_plain(self, trace),
            "https" => send_https(self, trace),
            other => Err(Error::UnsupportedScheme(other.to_string())),
        }
    }

    /// Whether following a redirect with `status` turns a `method` request into a
    /// body-less GET. This matches RFC 9110 §15.4 and the Fetch standard's
    /// HTTP-redirect fetch:
    /// - 301/302 only rewrite POST;
    /// - 303 rewrites everything except GET and HEAD;
    /// - 307/308 never rewrite.
    fn redirect_switches_to_get(status: u16, method: &str) -> bool {
        match status {
            301 | 302 => method.eq_ignore_ascii_case("POST"),
            303 => !method.eq_ignore_ascii_case("GET") && !method.eq_ignore_ascii_case("HEAD"),
            _ => false,
        }
    }

    /// Sends the request, optionally following redirects and using a cookie jar.
    ///
    /// Each hop sends a clone, so jar cookies never accumulate in `req`. On a redirect
    /// to a different scheme/host/port, `Authorization`, `Cookie` and the explicit
    /// Basic credentials are dropped. See `redirect_switches_to_get` for when the
    /// method/body are rewritten.
    ///
    /// # Errors
    /// - `max_time` elapsed before a hop;
    /// - more than `max_redirs` redirects;
    /// - an unresolvable `Location`;
    /// - any transport or protocol error from a hop.
    fn send_to(
        self,
        trace: &mut dyn Write,
        mut jar: Option<&mut crate::cookie::CookieJar>,
    ) -> Result<Response> {
        let mut req = self;
        let deadline = req.max_time.map(|d| std::time::Instant::now() + d);
        let mut hops_left = req.max_redirs;
        loop {
            if let Some(end) = deadline {
                if std::time::Instant::now() >= end {
                    return Err(Error::BadResponse("operation timed out".into()));
                }
            }
            let mut snapshot = req.clone();
            if let Some(j) = jar.as_deref_mut() {
                j.purge_expired();
                // The jar is authoritative for cookies when present.
                snapshot
                    .headers
                    .retain(|(k, _)| !k.eq_ignore_ascii_case("cookie"));
                if let Some(val) = j.cookie_header(&snapshot.url) {
                    snapshot.headers.push(("Cookie".to_string(), val));
                }
            }
            let resp = snapshot.send_once(trace)?;
            if let Some(j) = jar.as_deref_mut() {
                j.ingest_response(&req.url, &resp.headers);
            }
            if !req.follow_redirects || !is_redirect_status(resp.status) {
                return Ok(resp);
            }
            if hops_left == 0 {
                return Err(Error::BadResponse(format!(
                    "maximum ({}) redirects followed",
                    req.max_redirs
                )));
            }
            let location = match resp.header("location") {
                Some(l) => l.to_string(),
                // A redirect status without Location is handed back as-is.
                None => return Ok(resp),
            };
            let next_url = crate::url::resolve(&req.url, &location)?;
            let _ = writeln!(
                trace,
                "* Following redirect to {}",
                url_to_string(&next_url)
            );
            let host_changed = next_url.host != req.url.host
                || next_url.port != req.url.port
                || next_url.scheme != req.url.scheme;
            let switch_to_get = Self::redirect_switches_to_get(resp.status, &req.method);

            req.url = next_url;
            if host_changed {
                // Never forward credentials or cookies to a different origin.
                req.headers.retain(|(k, _)| {
                    !k.eq_ignore_ascii_case("authorization") && !k.eq_ignore_ascii_case("cookie")
                });
                req.basic_auth = None;
            }
            if switch_to_get {
                req.method = "GET".to_string();
                req.body = Vec::new();
                req.headers.retain(|(k, _)| {
                    !k.eq_ignore_ascii_case("content-type")
                        && !k.eq_ignore_ascii_case("content-length")
                        && !k.eq_ignore_ascii_case("transfer-encoding")
                });
            }
            hops_left -= 1;
        }
    }
}
/// True for the redirect statuses this client follows: 301, 302, 303, 307 and 308.
fn is_redirect_status(status: u16) -> bool {
    const FOLLOWED: [u16; 5] = [301, 302, 303, 307, 308];
    FOLLOWED.contains(&status)
}
/// Renders `scheme://host[:port]path` for trace output. The port is omitted when it
/// is the scheme default (80 for http, 443 for https).
fn url_to_string(u: &Url) -> String {
    let scheme = u.scheme.as_str();
    let port_is_default =
        (scheme == "http" && u.port == 80) || (scheme == "https" && u.port == 443);
    let authority = if port_is_default {
        u.host.clone()
    } else {
        format!("{}:{}", u.host, u.port)
    };
    format!("{scheme}://{authority}{}", u.path)
}
/// A complete response: status line, header fields in wire order, and body bytes.
/// On the HTTP/1.1 path the body has had any decodable `Content-Encoding` removed.
#[derive(Debug, Clone)]
pub struct Response {
    pub status: u16,
    pub reason: String,
    pub version: String,
    pub headers: Vec<(String, String)>,
    pub body: Vec<u8>,
}

impl Response {
    /// Value of the first header named `name` (ASCII case-insensitive), if any.
    pub fn header(&self, name: &str) -> Option<&str> {
        self.headers
            .iter()
            .find_map(|(key, value)| key.eq_ignore_ascii_case(name).then_some(value.as_str()))
    }
}
/// Sends `req` over plain `http://` using HTTP/1.1.
///
/// When going direct (no proxy, or the proxy is bypassed for this host), a pooled
/// keep-alive connection is tried first. If it turns out stale, a fresh connection
/// is opened.
fn send_plain(req: Request, trace: &mut dyn Write) -> Result<Response> {
    if req.http_version_pref == HttpVersionPref::Http3Only {
        return Err(Error::UnsupportedScheme(
            "http/3 requires https://, not http://".into(),
        ));
    }
    if req.http_version_pref == HttpVersionPref::Http3 {
        let _ = writeln!(
            trace,
            "* HTTP/3 requested but URL is http://; using HTTP/1.1 (h3 needs https)"
        );
    }
    let direct = match req.proxy {
        None => true,
        Some(_) => proxy_bypassed(&req),
    };
    let pooled = if direct {
        pool_checkout_plain(&req.url)
    } else {
        None
    };
    if let Some(bufrd) = pooled {
        let _ = writeln!(trace, "* Reusing existing connection from pool");
        match perform_on_pooled_plain(bufrd, &req, trace) {
            Ok(resp) => return Ok(resp),
            Err(PooledError::Hard(e)) => return Err(e),
            Err(PooledError::Stale(why)) => {
                let _ = writeln!(trace, "* Pooled connection unusable ({why}); reconnecting");
            }
        }
    }
    send_plain_fresh(req, direct, trace)
}
/// Runs one HTTP/1.1 exchange on a newly opened TCP connection. The connection is
/// offered to the pool afterwards only if `may_pool` is set.
fn send_plain_fresh(req: Request, may_pool: bool, trace: &mut dyn Write) -> Result<Response> {
    // Absolute-form request targets are required when talking to a plain HTTP proxy.
    let absolute_form = via_plain_http_proxy(&req);
    let mut conn = BufReader::new(tcp_connect(&req, trace)?);
    write_request(conn.get_mut(), &req, absolute_form, trace)?;
    let resp = read_response(&mut conn, &req.method, trace)?;
    finalize_plain(conn, &req, &resp, may_pool, trace);
    Ok(resp)
}
fn perform_on_pooled_plain(
mut bufrd: BufReader<TcpStream>,
req: &Request,
trace: &mut dyn Write,
) -> std::result::Result<Response, PooledError> {
if let Err(e) = write_request(bufrd.get_mut(), req, via_plain_http_proxy(req), trace) {
return Err(stale_or_hard(e));
}
let resp = match read_response(&mut bufrd, &req.method, trace) {
Ok(r) => r,
Err(e) => return Err(stale_or_hard(e)),
};
finalize_plain(bufrd, req, &resp, true, trace);
Ok(resp)
}
/// Returns the connection to the plain pool when allowed and the response leaves it
/// reusable; otherwise drops (closes) it. Both outcomes are logged to `trace`.
fn finalize_plain(
    bufrd: BufReader<TcpStream>,
    req: &Request,
    resp: &Response,
    may_pool: bool,
    trace: &mut dyn Write,
) {
    let keep_alive = may_pool && response_is_reusable(&req.method, resp);
    if !keep_alive {
        let _ = writeln!(trace, "* Connection closed");
        return;
    }
    let key = pool_key_for(&req.url);
    // A poisoned pool lock is still usable: the pool only holds idle connections.
    crate::pool::plain()
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner())
        .release(key, bufrd);
    let _ = writeln!(trace, "* Connection kept alive (pooled)");
}
/// Outcome of a failed exchange on a pooled connection.
enum PooledError {
    /// The pooled connection was dead (e.g. closed by the peer while idle); the
    /// callers log this reason and retry on a fresh connection.
    Stale(String),
    /// A genuine error that the callers return to the user unchanged.
    Hard(Error),
}
fn stale_or_hard(e: Error) -> PooledError {
match &e {
Error::UnexpectedEof => PooledError::Stale("connection closed by peer".into()),
Error::Io(io_err) => match io_err.kind() {
io::ErrorKind::UnexpectedEof
| io::ErrorKind::ConnectionReset
| io::ErrorKind::ConnectionAborted
| io::ErrorKind::BrokenPipe
| io::ErrorKind::NotConnected => PooledError::Stale(io_err.to_string()),
_ => PooledError::Hard(e),
},
_ => PooledError::Hard(e),
}
}
/// Pool key identifying an origin: scheme, host and port of `u`.
pub(crate) fn pool_key_for(u: &Url) -> crate::pool::Key {
    let scheme = u.scheme.clone();
    let host = u.host.clone();
    crate::pool::Key {
        scheme,
        host,
        port: u.port,
    }
}
/// Takes an idle plain connection for `u`'s origin out of the pool, if one exists.
fn pool_checkout_plain(u: &Url) -> Option<BufReader<TcpStream>> {
    let key = pool_key_for(u);
    let mut pool = crate::pool::plain()
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    pool.checkout(&key)
}
/// Whether `req` may use the TLS connection pool.
///
/// Only requests with the default trust posture qualify: verification on and no
/// custom CA bundle. Connections from other postures are never released into the pool.
pub(crate) fn tls_pool_eligible(req: &Request) -> bool {
    matches!((req.verify_tls, &req.ca_bundle), (true, None))
}
/// Takes an idle TLS connection for `req`'s origin from the pool. Always `None` for
/// requests that are not pool-eligible.
fn pool_checkout_tls(req: &Request) -> Option<BufReader<crate::tls::TlsStream<TcpStream>>> {
    tls_pool_eligible(req)
        .then(|| pool_key_for(&req.url))
        .and_then(|key| {
            crate::pool::tls()
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner())
                .checkout(&key)
        })
}
/// Whether the connection that carried `resp` can serve another request.
///
/// Requirements:
/// - no `Connection: close`;
/// - the body was delimited (Content-Length or chunked) or none was allowed;
/// - HTTP/1.1, or an explicit `Connection: keep-alive`.
fn response_is_reusable(method: &str, resp: &Response) -> bool {
    // True when any `Connection` header lists `token` (comma-separated, case-insensitive).
    let connection_lists = |token: &str| {
        resp.headers.iter().any(|(name, value)| {
            name.eq_ignore_ascii_case("connection")
                && value.split(',').any(|t| t.trim().eq_ignore_ascii_case(token))
        })
    };
    if connection_lists("close") {
        return false;
    }
    let framed = resp.headers.iter().any(|(name, value)| {
        name.eq_ignore_ascii_case("content-length")
            || (name.eq_ignore_ascii_case("transfer-encoding")
                && value.eq_ignore_ascii_case("chunked"))
    });
    let bodiless = method.eq_ignore_ascii_case("HEAD")
        || matches!(resp.status, 100..=199 | 204 | 304);
    // An undelimited body was read to EOF, so the connection is spent.
    if !framed && !bodiless {
        return false;
    }
    resp.version == "HTTP/1.1" || connection_lists("keep-alive")
}
/// True when an http:// request goes through a (non-bypassed) proxy. Such requests
/// use absolute-form targets and carry `Proxy-Authorization` directly.
pub(crate) fn via_plain_http_proxy(req: &Request) -> bool {
    req.url.scheme == "http" && req.proxy.is_some() && !proxy_bypassed(req)
}
/// Whether `req.no_proxy` exempts the target host from proxying.
///
/// Entries are trimmed; a lone `*` matches everything. Other entries match the host
/// exactly or as a dot-separated suffix (ASCII case-insensitive, leading `.` ignored).
pub(crate) fn proxy_bypassed(req: &Request) -> bool {
    let entries = || req.no_proxy.iter().map(String::as_str).map(str::trim);
    if entries().any(|entry| entry == "*") {
        return true;
    }
    let host = req.url.host.to_ascii_lowercase();
    entries()
        .map(|entry| entry.trim_start_matches('.').to_ascii_lowercase())
        .filter(|suffix| !suffix.is_empty())
        .any(|suffix| {
            host == suffix
                || matches!(host.strip_suffix(suffix.as_str()), Some(rest) if rest.ends_with('.'))
        })
}
/// Base64 `user:pass` for an `Authorization: Basic` header.
///
/// Explicit `basic_auth` wins over URL userinfo. A userinfo without `:` has an empty
/// password. Returns `None` when there are no credentials or both parts are empty.
pub(crate) fn effective_basic_auth(req: &Request) -> Option<String> {
    let (user, pass): (&str, &str) = match &req.basic_auth {
        Some((user, pass)) => (user.as_str(), pass.as_str()),
        None => {
            let info = req.url.userinfo.as_deref()?;
            info.split_once(':').unwrap_or((info, ""))
        }
    };
    if user.is_empty() && pass.is_empty() {
        return None;
    }
    Some(crate::websocket::base64_encode(
        format!("{user}:{pass}").as_bytes(),
    ))
}
/// Builds TLS options for `req`.
///
/// Starts from `TlsOpts::verifying()`, applies `req.verify_tls`, offers `alpn` (in
/// order), and loads the custom trust roots when a CA bundle is set.
///
/// # Errors
/// Failures from loading the CA bundle file.
pub(crate) fn tls_opts_from(req: &Request, alpn: &[&[u8]]) -> Result<crate::tls::TlsOpts> {
    let mut opts = crate::tls::TlsOpts::verifying();
    opts.verify = req.verify_tls;
    opts.alpn = alpn.iter().map(|proto| proto.to_vec()).collect();
    if let Some(path) = &req.ca_bundle {
        let roots = crate::tls::load_roots_from_file(path)?;
        opts.roots = Some(roots);
    }
    Ok(opts)
}
pub(crate) fn h3_should_fall_back(e: &Error) -> bool {
match e {
Error::Io(_) => true,
Error::UnsupportedScheme(_) | Error::InvalidUrl(_) => true,
Error::BadResponse(m) => {
m.starts_with("http3: connection closed")
|| m.starts_with("http3: peer closed")
|| m.starts_with("http3: build client")
|| m.starts_with("http3: open_bidi")
|| m.starts_with("http3: open_uni")
|| m.starts_with("http3: feed")
|| m.starts_with("http3: stream read")
|| m.starts_with("http3: stream write")
|| m.starts_with("http3: stream finish")
}
_ => false,
}
}
/// Sends several requests, delegating entirely to `crate::http2::send_multiplexed`.
/// NOTE(review): presumably one result per input request, in input order — confirm in `http2`.
pub fn send_multiplexed(reqs: Vec<Request>, trace: &mut dyn Write) -> Vec<Result<Response>> {
    crate::http2::send_multiplexed(reqs, trace)
}
/// Sends `req` over https://, choosing the protocol from `req.http_version_pref`.
///
/// Order of attempts:
/// 1. `Http3Only`: only `crate::http3::send`; its result is returned as-is.
/// 2. `Http3`: try HTTP/3. On errors accepted by `h3_should_fall_back`, continue
///    below; any other error is returned.
/// 3. `Http2Only`: only `crate::http2::send`.
/// 4. `Auto` / `Http3`: try HTTP/2; only `Error::H2NotNegotiated` falls through.
/// 5. HTTP/1.1: reuse a pooled TLS connection when going direct, else a fresh one.
fn send_https(req: Request, trace: &mut dyn Write) -> Result<Response> {
    match req.http_version_pref {
        HttpVersionPref::Http3Only => {
            let _ = writeln!(trace, "* Trying HTTP/3 (QUIC), required (--http3-only)");
            return crate::http3::send(req, trace);
        }
        HttpVersionPref::Http3 => {
            let _ = writeln!(trace, "* Trying HTTP/3 (QUIC)...");
            // Clone: `req` is still needed by the fallback paths below.
            match crate::http3::send(req.clone(), trace) {
                Ok(resp) => return Ok(resp),
                Err(e) if h3_should_fall_back(&e) => {
                    let _ = writeln!(trace, "* HTTP/3 failed ({e}), falling back to HTTP/2/1.1");
                }
                Err(e) => return Err(e),
            }
        }
        _ => {}
    }
    match req.http_version_pref {
        HttpVersionPref::Http2Only => {
            let _ = writeln!(trace, "* HTTP/2 required (--http2)");
            return crate::http2::send(req, trace);
        }
        HttpVersionPref::Auto | HttpVersionPref::Http3 => {
            let _ = writeln!(trace, "* Trying HTTP/2 via ALPN (h2)");
            match crate::http2::send(req.clone(), trace) {
                Ok(resp) => return Ok(resp),
                // Only a missing h2 ALPN selection falls back; other errors are final.
                Err(Error::H2NotNegotiated) => {
                    let _ = writeln!(
                        trace,
                        "* ALPN: server did not select h2, falling back to HTTP/1.1"
                    );
                }
                Err(e) => return Err(e),
            }
        }
        HttpVersionPref::Http11Only => {
            let _ = writeln!(trace, "* HTTP/1.1 forced (--http1.1)");
        }
        HttpVersionPref::Http3Only => unreachable!("Http3Only handled above"),
    }
    // HTTP/1.1 path. Pooled connections are only reused for direct (non-proxied) requests.
    let direct = req.proxy.is_none() || proxy_bypassed(&req);
    if direct {
        if let Some(bufrd) = pool_checkout_tls(&req) {
            let _ = writeln!(trace, "* Reusing existing connection from pool");
            match perform_on_pooled_tls(bufrd, &req, trace) {
                Ok(resp) => return Ok(resp),
                Err(PooledError::Stale(why)) => {
                    let _ = writeln!(trace, "* Pooled connection unusable ({why}); reconnecting");
                }
                Err(PooledError::Hard(e)) => return Err(e),
            }
        }
    }
    send_https_fresh(req, direct, trace)
}
/// Runs one HTTP/1.1 exchange over a newly opened TLS connection.
///
/// When an unbypassed proxy is configured, the TCP connection goes to the proxy and a
/// CONNECT tunnel is set up before the TLS handshake. No ALPN protocols are offered.
fn send_https_fresh(req: Request, may_pool: bool, trace: &mut dyn Write) -> Result<Response> {
    let tcp = tcp_connect(&req, trace)?;
    let tunnel_via = match req.proxy.as_ref() {
        Some(p) if req.url.scheme == "https" && !proxy_bypassed(&req) => Some(p),
        _ => None,
    };
    if let Some(p) = tunnel_via {
        connect_tunnel(&tcp, &req.url, p, trace)?;
    }
    let opts = tls_opts_from(&req, &[])?;
    let tls = crate::tls::connect_over_tls(tcp, &req.url.host, opts)?;
    write_tls_info(&tls, trace);
    let mut conn = BufReader::new(tls);
    write_request(conn.get_mut(), &req, false, trace)?;
    let resp = read_response(&mut conn, &req.method, trace)?;
    finalize_tls(conn, &req, &resp, may_pool, trace);
    Ok(resp)
}
fn perform_on_pooled_tls(
mut bufrd: BufReader<crate::tls::TlsStream<TcpStream>>,
req: &Request,
trace: &mut dyn Write,
) -> std::result::Result<Response, PooledError> {
if let Err(e) = write_request(bufrd.get_mut(), req, false, trace) {
return Err(stale_or_hard(e));
}
let resp = match read_response(&mut bufrd, &req.method, trace) {
Ok(r) => r,
Err(e) => return Err(stale_or_hard(e)),
};
finalize_tls(bufrd, req, &resp, true, trace);
Ok(resp)
}
/// Returns a TLS connection to the pool, or drops (closes) it.
///
/// Pooling requires `may_pool`, a pool-eligible trust posture and a reusable response.
/// Both outcomes are logged to `trace`.
fn finalize_tls(
    bufrd: BufReader<crate::tls::TlsStream<TcpStream>>,
    req: &Request,
    resp: &Response,
    may_pool: bool,
    trace: &mut dyn Write,
) {
    let keep_alive =
        may_pool && tls_pool_eligible(req) && response_is_reusable(&req.method, resp);
    if !keep_alive {
        let _ = writeln!(trace, "* Connection closed");
        return;
    }
    let key = pool_key_for(&req.url);
    crate::pool::tls()
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner())
        .release(key, bufrd);
    let _ = writeln!(trace, "* Connection kept alive (pooled)");
}
pub(crate) fn tcp_connect(req: &Request, trace: &mut dyn Write) -> Result<TcpStream> {
let proxy = req.proxy.as_ref().filter(|_| !proxy_bypassed(req));
let (target_host, target_port, via_proxy_label) = match proxy {
Some(p) => (p.host.as_str(), p.port, true),
None => (req.url.host.as_str(), req.url.port, false),
};
let addr = format!("{target_host}:{target_port}");
let first = std::net::ToSocketAddrs::to_socket_addrs(&addr)?
.next()
.ok_or_else(|| Error::InvalidUrl(target_host.to_string()))?;
let _ = writeln!(trace, "* Trying {first}...");
let stream = match req.connect_timeout {
Some(t) => TcpStream::connect_timeout(&first, t)?,
None => TcpStream::connect(first)?,
};
let peer = stream.peer_addr().unwrap_or(first);
if via_proxy_label {
let _ = writeln!(
trace,
"* Connected to proxy {} ({}) port {}",
target_host,
peer.ip(),
peer.port()
);
} else {
let _ = writeln!(
trace,
"* Connected to {} ({}) port {}",
req.url.host,
peer.ip(),
peer.port()
);
}
stream.set_read_timeout(req.read_timeout)?;
stream.set_write_timeout(req.read_timeout)?;
Ok(stream)
}
pub(crate) fn connect_tunnel<S: Read + Write>(
mut stream: S,
target: &Url,
proxy: &ProxyConfig,
trace: &mut dyn Write,
) -> Result<()> {
let host_port = format!("{}:{}", target.host, target.port);
let mut buf = Vec::with_capacity(256);
write!(&mut buf, "CONNECT {host_port} HTTP/1.1\r\n")?;
write!(&mut buf, "Host: {host_port}\r\n")?;
write!(&mut buf, "User-Agent: {DEFAULT_USER_AGENT}\r\n")?;
write!(&mut buf, "Proxy-Connection: Keep-Alive\r\n")?;
if let Some((user, pass)) = &proxy.auth {
let combined = format!("{user}:{pass}");
let creds = crate::websocket::base64_encode(combined.as_bytes());
write!(&mut buf, "Proxy-Authorization: Basic {creds}\r\n")?;
}
write!(&mut buf, "\r\n")?;
let head = String::from_utf8_lossy(&buf);
let head_no_final_crlf = head.strip_suffix("\r\n").unwrap_or(&head);
for line in head_no_final_crlf.split("\r\n") {
let _ = writeln!(trace, "> {line}");
}
stream.write_all(&buf)?;
stream.flush()?;
let mut line_buf: Vec<u8> = Vec::with_capacity(128);
let mut byte = [0u8; 1];
let mut status_line: Option<String> = None;
let mut total = 0usize;
loop {
if total > MAX_HEADER_BYTES {
return Err(Error::BadResponse(
"CONNECT response headers exceed 64 KiB".into(),
));
}
let n = stream.read(&mut byte)?;
if n == 0 {
return Err(Error::UnexpectedEof);
}
total += 1;
line_buf.push(byte[0]);
if byte[0] == b'\n' {
let trimmed_owned = String::from_utf8_lossy(
line_buf
.strip_suffix(b"\n")
.unwrap_or(&line_buf)
.strip_suffix(b"\r")
.unwrap_or(line_buf.strip_suffix(b"\n").unwrap_or(&line_buf)),
)
.into_owned();
let _ = writeln!(trace, "< {trimmed_owned}");
if status_line.is_none() {
status_line = Some(trimmed_owned.clone());
}
if trimmed_owned.is_empty() {
break;
}
line_buf.clear();
}
}
let status = status_line.ok_or_else(|| Error::BadResponse("CONNECT: no status line".into()))?;
let parts: Vec<&str> = status.splitn(3, ' ').collect();
if parts.len() < 2 {
return Err(Error::BadResponse(format!(
"CONNECT: malformed status line {status:?}"
)));
}
let code: u16 = parts[1]
.parse()
.map_err(|_| Error::BadResponse(format!("CONNECT: bad status {:?}", parts[1])))?;
if !(200..300).contains(&code) {
return Err(Error::BadResponse(format!(
"CONNECT to {host_port} failed: {status}"
)));
}
let _ = writeln!(trace, "* CONNECT tunnel established to {host_port}");
Ok(())
}
/// Writes TLS session diagnostics to `trace`:
/// - the negotiated protocol version;
/// - the ALPN outcome;
/// - for each peer certificate, its subject/issuer common name and validity window.
///
/// Purely informational: trace write errors are ignored, and certificates that fail to
/// parse are reported by size instead of causing an error.
pub(crate) fn write_tls_info<S: Read + Write>(
    tls: &crate::tls::TlsStream<S>,
    trace: &mut dyn Write,
) {
    if let Some(v) = tls.negotiated_version() {
        let _ = writeln!(trace, "* SSL connection using {v:?}");
    }
    match tls.alpn_selected() {
        Some(p) => {
            let _ = writeln!(
                trace,
                "* ALPN: server accepted {}",
                String::from_utf8_lossy(p)
            );
        }
        None => {
            let _ = writeln!(trace, "* ALPN: no protocol negotiated");
        }
    }
    let certs = tls.peer_certificates();
    let _ = writeln!(trace, "* Server certificate chain: {} cert(s)", certs.len());
    for (i, der) in certs.iter().enumerate() {
        match purecrypto::x509::Certificate::from_der(der.clone()) {
            Ok(cert) => {
                // Missing or unparseable names are shown as "?".
                let subject = cert
                    .subject()
                    .ok()
                    .and_then(|d| d.common_name)
                    .unwrap_or_else(|| "?".into());
                let issuer = cert
                    .issuer()
                    .ok()
                    .and_then(|d| d.common_name)
                    .unwrap_or_else(|| "?".into());
                let _ = writeln!(trace, "* [{i}] subject CN: {subject}");
                let _ = writeln!(trace, "* issuer CN: {issuer}");
                if let Ok(v) = cert.validity() {
                    let _ = writeln!(
                        trace,
                        "* valid: {} -> {}",
                        v.not_before.as_str(),
                        v.not_after.as_str()
                    );
                }
            }
            Err(_) => {
                let _ = writeln!(trace, "* [{i}] (DER unparseable, {} bytes)", der.len());
            }
        }
    }
}
/// True when `name` is a non-empty RFC 9110 `token`: ASCII letters and digits plus
/// ``!#$%&'*+-.^_`|~``.
fn is_valid_header_name(name: &str) -> bool {
    const TCHAR_SYMBOLS: &[u8] = b"!#$%&'*+-.^_`|~";
    let is_tchar = |b: u8| b.is_ascii_alphanumeric() || TCHAR_SYMBOLS.contains(&b);
    !name.is_empty() && name.bytes().all(is_tchar)
}
/// True when `value` contains CR, LF or NUL. These would allow header injection or
/// truncation on the wire.
fn header_value_has_forbidden(value: &str) -> bool {
    value.contains(['\r', '\n', '\0'])
}
fn validate_header(name: &str, value: &str) -> Result<()> {
if !is_valid_header_name(name) {
return Err(Error::BadResponse(format!("invalid header name: {name:?}")));
}
if header_value_has_forbidden(value) {
return Err(Error::BadResponse(format!(
"invalid header value for {name:?}"
)));
}
Ok(())
}
fn validate_method(method: &str) -> Result<()> {
if method.is_empty() || method.bytes().any(|b| b < 0x20 || b == 0x7f || b == b' ') {
return Err(Error::BadResponse(format!("invalid method: {method:?}")));
}
Ok(())
}
/// Serialises `req` as an HTTP/1.1 request to `w`, echoing the head to `trace`.
///
/// The method and every caller header are validated before anything is written, so a
/// rejected request leaves `w` untouched. With `absolute_form`, the request target is
/// the full URL (plain HTTP proxy) and proxy credentials go into `Proxy-Authorization`.
///
/// Headers added by default:
/// - `Host`, unless the caller supplied one (a second Host field makes servers reply 400);
/// - Basic `Authorization`, User-Agent, Accept and Accept-Encoding, unless supplied;
/// - `Content-Length`, when there is a body or the method is POST/PUT/PATCH (an empty
///   body then gets `Content-Length: 0`), unless the caller set `Content-Length` or
///   `Transfer-Encoding` (RFC 9112 forbids sending both).
///
/// # Errors
/// Invalid method/header ([`Error::BadResponse`]) or write failures.
fn write_request<W: Write>(
    mut w: W,
    req: &Request,
    absolute_form: bool,
    trace: &mut dyn Write,
) -> Result<()> {
    validate_method(&req.method)?;
    for (k, v) in &req.headers {
        validate_header(k, v)?;
    }

    let url = &req.url;
    let default_port =
        (url.scheme == "http" && url.port == 80) || (url.scheme == "https" && url.port == 443);
    let authority = if default_port {
        url.host.clone()
    } else {
        format!("{}:{}", url.host, url.port)
    };
    let has_header = |name: &str| req.headers.iter().any(|(k, _)| k.eq_ignore_ascii_case(name));

    let mut buf = Vec::with_capacity(256);
    if absolute_form {
        write!(
            &mut buf,
            "{} {}://{authority}{} HTTP/1.1\r\n",
            req.method, url.scheme, url.path
        )?;
    } else {
        write!(&mut buf, "{} {} HTTP/1.1\r\n", req.method, url.path)?;
    }
    if !has_header("host") {
        write!(&mut buf, "Host: {authority}\r\n")?;
    }
    if absolute_form {
        if let Some((user, pass)) = req.proxy.as_ref().and_then(|p| p.auth.as_ref()) {
            let creds = crate::websocket::base64_encode(format!("{user}:{pass}").as_bytes());
            write!(&mut buf, "Proxy-Authorization: Basic {creds}\r\n")?;
        }
    }
    for (k, v) in &req.headers {
        write!(&mut buf, "{k}: {v}\r\n")?;
    }
    if !has_header("authorization") {
        if let Some(creds) = effective_basic_auth(req) {
            write!(&mut buf, "Authorization: Basic {creds}\r\n")?;
        }
    }
    if !has_header("user-agent") {
        write!(&mut buf, "User-Agent: {DEFAULT_USER_AGENT}\r\n")?;
    }
    if !has_header("accept") {
        write!(&mut buf, "Accept: */*\r\n")?;
    }
    if !has_header("accept-encoding") {
        write!(&mut buf, "Accept-Encoding: gzip, deflate\r\n")?;
    }
    // RFC 9110 §8.6: send Content-Length when the method defines content, even if it
    // is empty; some servers answer 411 Length Required otherwise.
    let method_defines_body = ["POST", "PUT", "PATCH"]
        .iter()
        .any(|m| req.method.eq_ignore_ascii_case(m));
    if (!req.body.is_empty() || method_defines_body)
        && !has_header("content-length")
        && !has_header("transfer-encoding")
    {
        write!(&mut buf, "Content-Length: {}\r\n", req.body.len())?;
    }
    write!(&mut buf, "\r\n")?;

    let head = String::from_utf8_lossy(&buf);
    let head_no_final_crlf = head.strip_suffix("\r\n").unwrap_or(&head);
    for line in head_no_final_crlf.split("\r\n") {
        let _ = writeln!(trace, "> {line}");
    }
    w.write_all(&buf)?;
    if !req.body.is_empty() {
        let _ = writeln!(trace, "* uploading {} body bytes", req.body.len());
        w.write_all(&req.body)?;
    }
    w.flush()?;
    Ok(())
}
fn read_response<R: Read>(
r: &mut BufReader<R>,
method: &str,
trace: &mut dyn Write,
) -> Result<Response> {
let mut status_line = String::new();
let n = r.read_line(&mut status_line)?;
if n == 0 {
return Err(Error::UnexpectedEof);
}
let trimmed_status = status_line.trim_end_matches(['\r', '\n']);
let _ = writeln!(trace, "< {trimmed_status}");
let (version, status, reason) = parse_status_line(trimmed_status)?;
let mut headers: Vec<(String, String)> = Vec::new();
let mut header_bytes = 0usize;
loop {
let mut line = String::new();
let n = r.read_line(&mut line)?;
if n == 0 {
return Err(Error::UnexpectedEof);
}
header_bytes += n;
if header_bytes > MAX_HEADER_BYTES {
return Err(Error::BadResponse("headers exceed 64 KiB".into()));
}
let trimmed = line.trim_end_matches(['\r', '\n']);
let _ = writeln!(trace, "< {trimmed}");
if trimmed.is_empty() {
break;
}
let (k, v) = trimmed
.split_once(':')
.ok_or_else(|| Error::BadResponse(format!("malformed header line: {trimmed:?}")))?;
headers.push((k.trim().to_string(), v.trim().to_string()));
}
let body = read_body(r, &headers, &version, status, method)?;
let wire_len = body.len();
let _ = writeln!(trace, "* Received {wire_len} body bytes");
let (headers, body) = maybe_decode_body(headers, body, trace)?;
Ok(Response {
status,
reason,
version,
headers,
body,
})
}
/// Header fields plus body bytes, as returned by `maybe_decode_body`.
pub(crate) type HeadersAndBody = (Vec<(String, String)>, Vec<u8>);
/// Decodes `body` per the first `Content-Encoding` header, if any.
///
/// When `crate::compress::decode_body` reports a successful decode, the size change
/// is logged and the headers pass through `crate::compress::strip_after_decode`.
/// Otherwise the headers are returned unchanged together with the body `decode_body`
/// handed back.
///
/// # Errors
/// Whatever `crate::compress::decode_body` reports.
pub(crate) fn maybe_decode_body(
    headers: Vec<(String, String)>,
    body: Vec<u8>,
    trace: &mut dyn Write,
) -> Result<HeadersAndBody> {
    let content_encoding: Option<String> = headers
        .iter()
        .find(|(name, _)| name.eq_ignore_ascii_case("content-encoding"))
        .map(|(_, value)| value.clone());
    let encoding = match content_encoding {
        Some(encoding) => encoding,
        None => return Ok((headers, body)),
    };
    let compressed_len = body.len();
    let decoded = crate::compress::decode_body(body, &encoding)?;
    if !decoded.decoded {
        return Ok((headers, decoded.body));
    }
    let _ = writeln!(
        trace,
        "* Decompressed body: {} -> {} bytes ({})",
        compressed_len,
        decoded.body.len(),
        encoding
    );
    Ok((crate::compress::strip_after_decode(headers), decoded.body))
}
fn parse_status_line(line: &str) -> Result<(String, u16, String)> {
let mut parts = line.splitn(3, ' ');
let version = parts
.next()
.ok_or_else(|| Error::BadResponse(format!("missing version: {line:?}")))?
.to_string();
if !version.starts_with("HTTP/") {
return Err(Error::BadResponse(format!("not HTTP: {version}")));
}
let status: u16 = parts
.next()
.ok_or_else(|| Error::BadResponse(format!("missing status: {line:?}")))?
.parse()
.map_err(|_| Error::BadResponse(format!("bad status: {line:?}")))?;
let reason = parts.next().unwrap_or("").to_string();
Ok((version, status, reason))
}
fn parse_content_length(headers: &[(String, String)]) -> Result<Option<u64>> {
let mut seen: Option<u64> = None;
for (k, v) in headers {
if !k.eq_ignore_ascii_case("content-length") {
continue;
}
for part in v.split(',') {
let n: u64 = part
.trim()
.parse()
.map_err(|_| Error::BadResponse(format!("bad Content-Length: {v:?}")))?;
match seen {
Some(prev) if prev != n => {
return Err(Error::BadResponse(
"conflicting Content-Length values".into(),
));
}
_ => seen = Some(n),
}
}
}
Ok(seen)
}
fn read_body<R: BufRead>(
r: &mut R,
headers: &[(String, String)],
_version: &str,
status: u16,
method: &str,
) -> Result<Vec<u8>> {
if method.eq_ignore_ascii_case("HEAD")
|| (100..200).contains(&status)
|| status == 204
|| status == 304
{
return Ok(Vec::new());
}
let has_te = headers
.iter()
.any(|(k, _)| k.eq_ignore_ascii_case("transfer-encoding"));
let has_cl = headers
.iter()
.any(|(k, _)| k.eq_ignore_ascii_case("content-length"));
if has_te && has_cl {
return Err(Error::BadResponse(
"both Transfer-Encoding and Content-Length present".into(),
));
}
let chunked = headers.iter().any(|(k, v)| {
k.eq_ignore_ascii_case("transfer-encoding") && v.eq_ignore_ascii_case("chunked")
});
if chunked {
return read_chunked(r);
}
let content_length = parse_content_length(headers)?;
let mut body = Vec::new();
match content_length {
Some(len) => {
if len > MAX_BODY_BYTES as u64 {
return Err(Error::BadResponse(format!("body too large: {len}")));
}
body.reserve(len as usize);
r.take(len).read_to_end(&mut body)?;
if (body.len() as u64) < len {
return Err(Error::UnexpectedEof);
}
}
None => {
r.take(MAX_BODY_BYTES as u64).read_to_end(&mut body)?;
}
}
Ok(body)
}
fn read_chunked<R: BufRead>(r: &mut R) -> Result<Vec<u8>> {
let mut body = Vec::new();
loop {
let mut size_line = String::new();
let n = r.read_line(&mut size_line)?;
if n == 0 {
return Err(Error::UnexpectedEof);
}
let size_str = size_line
.trim_end_matches(['\r', '\n'])
.split(';')
.next()
.unwrap_or("");
let size = usize::from_str_radix(size_str.trim(), 16)
.map_err(|_| Error::BadResponse(format!("bad chunk size: {size_str:?}")))?;
if body.len().saturating_add(size) > MAX_BODY_BYTES {
return Err(Error::BadResponse("body too large".into()));
}
if size == 0 {
loop {
let mut t = String::new();
let n = r.read_line(&mut t)?;
if n == 0 || t.trim_end_matches(['\r', '\n']).is_empty() {
break;
}
}
break;
}
let start = body.len();
body.resize(start + size, 0);
r.read_exact(&mut body[start..])?;
let mut crlf = [0u8; 2];
r.read_exact(&mut crlf)?;
if &crlf != b"\r\n" {
return Err(Error::BadResponse("missing CRLF after chunk".into()));
}
}
Ok(body)
}
#[cfg(test)]
mod tests {
    // Unit tests for status-line parsing, header/method validation, request
    // serialisation, HTTP/3 preference handling, Content-Length parsing, body
    // framing checks, proxy configuration/bypass and CONNECT tunnelling.
    use super::*;

    #[test]
    fn parses_status_line_ok() {
        let (v, s, r) = parse_status_line("HTTP/1.1 200 OK").unwrap();
        assert_eq!(v, "HTTP/1.1");
        assert_eq!(s, 200);
        assert_eq!(r, "OK");
    }

    #[test]
    fn parses_status_line_no_reason() {
        let (_, s, r) = parse_status_line("HTTP/1.0 204").unwrap();
        assert_eq!(s, 204);
        assert_eq!(r, "");
    }

    #[test]
    fn rejects_non_http() {
        assert!(parse_status_line("RTSP/1.0 200 OK").is_err());
    }

    #[test]
    fn header_name_token_validation() {
        assert!(is_valid_header_name("X-Custom-Header"));
        assert!(is_valid_header_name("Content-Type"));
        assert!(!is_valid_header_name(""));
        assert!(!is_valid_header_name("Bad Name"));
        assert!(!is_valid_header_name("Bad:Name"));
        assert!(!is_valid_header_name("Bad\r\nName"));
    }

    #[test]
    fn validate_header_rejects_crlf_in_value() {
        assert!(validate_header("X", "ok").is_ok());
        assert!(validate_header("X", "evil\r\nInjected: 1").is_err());
        assert!(validate_header("X", "evil\rstuff").is_err());
        assert!(validate_header("X", "evil\nstuff").is_err());
        assert!(validate_header("X", "evil\0stuff").is_err());
    }

    #[test]
    fn validate_method_rejects_control_and_space() {
        assert!(validate_method("GET").is_ok());
        assert!(validate_method("PROPFIND").is_ok());
        assert!(validate_method("").is_err());
        assert!(validate_method("GET HTTP/1.1\r\nEvil:").is_err());
        assert!(validate_method("BAD\r\n").is_err());
        assert!(validate_method("BAD METHOD").is_err());
    }

    #[test]
    fn write_request_refuses_injected_header() {
        let mut req = Request::get("http://example.com/").unwrap();
        req.headers
            .push(("X-Evil".into(), "a\r\nInjected: 1".into()));
        let mut sink = Vec::new();
        let mut trace = Vec::new();
        let err = write_request(&mut sink, &req, false, &mut trace).unwrap_err();
        assert!(matches!(err, Error::BadResponse(_)));
        // Validation must happen before any bytes hit the wire, otherwise a
        // partially-written request could still reach the peer.
        assert!(sink.is_empty(), "nothing should have been written");
    }

    #[test]
    fn write_request_refuses_injected_method() {
        let mut req = Request::get("http://example.com/").unwrap();
        req.method = "GET\r\nEvil: 1".into();
        let mut sink = Vec::new();
        let mut trace = Vec::new();
        assert!(write_request(&mut sink, &req, false, &mut trace).is_err());
        assert!(sink.is_empty());
    }

    #[test]
    fn http3_builder_methods_set_pref() {
        let r = Request::get("https://example.com/").unwrap().http3();
        assert_eq!(r.http_version_pref, HttpVersionPref::Http3);
        let r = Request::get("https://example.com/").unwrap().http3_only();
        assert_eq!(r.http_version_pref, HttpVersionPref::Http3Only);
        let r = Request::get("https://example.com/")
            .unwrap()
            .http_version(HttpVersionPref::Http3);
        assert_eq!(r.http_version_pref, HttpVersionPref::Http3);
    }

    #[test]
    fn h3_fallback_classification() {
        use std::io;
        assert!(h3_should_fall_back(&Error::Io(io::Error::new(
            io::ErrorKind::TimedOut,
            "x"
        ))));
        assert!(h3_should_fall_back(&Error::BadResponse(
            "http3: feed: Decode".into()
        )));
        assert!(h3_should_fall_back(&Error::BadResponse(
            "http3: connection closed mid-handshake".into()
        )));
        assert!(h3_should_fall_back(&Error::BadResponse(
            "http3: peer closed connection".into()
        )));
        assert!(!h3_should_fall_back(&Error::BadResponse(
            "qpack: dynamic index out of range".into()
        )));
        assert!(!h3_should_fall_back(&Error::H2NotNegotiated));
    }

    #[test]
    fn http3_only_over_plaintext_http_errors() {
        let req = Request::get("http://example.com/").unwrap().http3_only();
        let mut trace = Vec::new();
        let err = send_plain(req, &mut trace).unwrap_err();
        assert!(matches!(err, Error::UnsupportedScheme(_)));
    }

    #[test]
    fn tls_pool_eligible_only_for_default_posture() {
        let mut req = Request::get("https://example.com/").unwrap();
        assert!(tls_pool_eligible(&req));
        req.verify_tls = false;
        assert!(!tls_pool_eligible(&req));
        req.verify_tls = true;
        req.ca_bundle = Some("/tmp/ca.pem".into());
        assert!(!tls_pool_eligible(&req));
    }

    #[test]
    fn content_length_single_ok() {
        let h = vec![("Content-Length".to_string(), "42".to_string())];
        assert_eq!(parse_content_length(&h).unwrap(), Some(42));
    }

    #[test]
    fn content_length_absent_is_none() {
        let h = vec![("X".to_string(), "y".to_string())];
        assert_eq!(parse_content_length(&h).unwrap(), None);
    }

    #[test]
    fn content_length_duplicate_agreeing_ok() {
        let h = vec![
            ("Content-Length".to_string(), "5".to_string()),
            ("content-length".to_string(), "5".to_string()),
        ];
        assert_eq!(parse_content_length(&h).unwrap(), Some(5));
    }

    #[test]
    fn content_length_conflicting_rejected() {
        let h = vec![
            ("Content-Length".to_string(), "5".to_string()),
            ("Content-Length".to_string(), "6".to_string()),
        ];
        assert!(parse_content_length(&h).is_err());
    }

    #[test]
    fn content_length_comma_list_conflicting_rejected() {
        let h = vec![("Content-Length".to_string(), "5, 6".to_string())];
        assert!(parse_content_length(&h).is_err());
    }

    #[test]
    fn content_length_comma_list_agreeing_ok() {
        let h = vec![("Content-Length".to_string(), "5, 5".to_string())];
        assert_eq!(parse_content_length(&h).unwrap(), Some(5));
    }

    #[test]
    fn content_length_unparseable_rejected() {
        let h = vec![("Content-Length".to_string(), "not-a-number".to_string())];
        assert!(parse_content_length(&h).is_err());
    }

    #[test]
    fn read_body_rejects_te_and_cl_together() {
        use std::io::Cursor;
        let headers = vec![
            ("Transfer-Encoding".to_string(), "chunked".to_string()),
            ("Content-Length".to_string(), "5".to_string()),
        ];
        let mut r = BufReader::new(Cursor::new(b"0\r\n\r\n".to_vec()));
        let err = read_body(&mut r, &headers, "HTTP/1.1", 200, "GET").unwrap_err();
        assert!(matches!(err, Error::BadResponse(_)));
    }

    #[test]
    fn read_body_rejects_conflicting_content_length() {
        use std::io::Cursor;
        let headers = vec![
            ("Content-Length".to_string(), "3".to_string()),
            ("Content-Length".to_string(), "4".to_string()),
        ];
        let mut r = BufReader::new(Cursor::new(b"abcd".to_vec()));
        assert!(read_body(&mut r, &headers, "HTTP/1.1", 200, "GET").is_err());
    }

    #[test]
    fn proxy_parse_basic() {
        let p = ProxyConfig::parse("http://proxy.example:3128").unwrap();
        assert_eq!(p.host, "proxy.example");
        assert_eq!(p.port, 3128);
        assert!(p.auth.is_none());
    }

    #[test]
    fn proxy_parse_with_creds() {
        let p = ProxyConfig::parse("http://alice:hunter2@proxy:8080").unwrap();
        assert_eq!(p.host, "proxy");
        assert_eq!(p.port, 8080);
        assert_eq!(p.auth.as_ref().unwrap().0, "alice");
        assert_eq!(p.auth.as_ref().unwrap().1, "hunter2");
    }

    #[test]
    fn proxy_parse_bare_hostport_is_http() {
        let p = ProxyConfig::parse("proxy.local:8080").unwrap();
        assert_eq!(p.host, "proxy.local");
        assert_eq!(p.port, 8080);
    }

    #[test]
    fn proxy_parse_rejects_https() {
        let err = ProxyConfig::parse("https://proxy:443").unwrap_err();
        // `matches!` only evaluates to a bool; without `assert!` the test
        // would pass no matter which error variant was returned.
        assert!(
            matches!(err, Error::UnsupportedScheme(_)),
            "unexpected error: {err:?}"
        );
    }

    #[test]
    fn proxy_bypass_matches_suffix() {
        let mut req = Request::get("http://api.example.com/x").unwrap();
        req.proxy = Some(ProxyConfig::parse("http://proxy:8080").unwrap());
        req.no_proxy = vec!["example.com".into()];
        assert!(proxy_bypassed(&req));
        req.url = Url::parse("http://other.org/x").unwrap();
        assert!(!proxy_bypassed(&req));
    }

    #[test]
    fn proxy_bypass_wildcard() {
        let mut req = Request::get("http://anywhere/").unwrap();
        req.proxy = Some(ProxyConfig::parse("http://p:1").unwrap());
        req.no_proxy = vec!["*".into()];
        assert!(proxy_bypassed(&req));
    }

    #[test]
    fn connect_tunnel_happy_path() {
        use std::io::{self, Cursor};

        // In-memory stream: records everything written, replays `reply`,
        // and only once `reply` is exhausted hands out `trailing` bytes.
        struct Mock {
            written: Vec<u8>,
            reply: Cursor<Vec<u8>>,
            trailing: Vec<u8>,
        }

        impl Read for Mock {
            fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
                let n = self.reply.read(buf)?;
                if n == 0 && !self.trailing.is_empty() {
                    let take = buf.len().min(self.trailing.len());
                    buf[..take].copy_from_slice(&self.trailing[..take]);
                    self.trailing.drain(..take);
                    return Ok(take);
                }
                Ok(n)
            }
        }

        impl Write for Mock {
            fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
                self.written.extend_from_slice(buf);
                Ok(buf.len())
            }

            fn flush(&mut self) -> io::Result<()> {
                Ok(())
            }
        }

        let mut mock = Mock {
            written: Vec::new(),
            reply: Cursor::new(b"HTTP/1.1 200 Connection established\r\n\r\n".to_vec()),
            // Stands in for the first byte of the next protocol layer
            // (0x16 is the TLS handshake record type); the tunnel setup must
            // leave it unread on the stream.
            trailing: vec![0x16],
        };
        let target = Url::parse("https://origin.example:443/").unwrap();
        let proxy = ProxyConfig {
            host: "proxy".into(),
            port: 3128,
            auth: Some(("u".into(), "p".into())),
        };
        connect_tunnel(&mut mock, &target, &proxy, &mut io::sink()).unwrap();

        let written = String::from_utf8(mock.written.clone()).unwrap();
        assert!(
            written.starts_with("CONNECT origin.example:443 HTTP/1.1\r\n"),
            "request line missing: {written:?}",
        );
        assert!(
            written.contains("Host: origin.example:443\r\n"),
            "Host header missing: {written:?}",
        );
        // "dTpw" is base64("u:p").
        assert!(
            written.contains("Proxy-Authorization: Basic dTpw\r\n"),
            "auth header missing or wrong: {written:?}",
        );

        let mut byte = [0u8; 1];
        assert_eq!(mock.read(&mut byte).unwrap(), 1);
        assert_eq!(byte[0], 0x16, "next-layer byte was consumed by the tunnel");
    }

    #[test]
    fn connect_tunnel_reports_407() {
        use std::io::{self, Cursor};

        let payload: &[u8] =
            b"HTTP/1.1 407 Proxy Authentication Required\r\nProxy-Authenticate: Basic\r\n\r\n";

        // Read side replays the canned proxy reply; write side collects the
        // CONNECT request into a borrowed buffer.
        struct RW<'a> {
            inner: Cursor<&'a [u8]>,
            sink: &'a mut Vec<u8>,
        }

        impl Read for RW<'_> {
            fn read(&mut self, b: &mut [u8]) -> io::Result<usize> {
                self.inner.read(b)
            }
        }

        impl Write for RW<'_> {
            fn write(&mut self, b: &[u8]) -> io::Result<usize> {
                self.sink.extend_from_slice(b);
                Ok(b.len())
            }

            fn flush(&mut self) -> io::Result<()> {
                Ok(())
            }
        }

        let mut sink = Vec::new();
        let mut rw = RW {
            inner: Cursor::new(payload),
            sink: &mut sink,
        };
        let target = Url::parse("https://origin/").unwrap();
        let proxy = ProxyConfig {
            host: "p".into(),
            port: 1,
            auth: None,
        };
        let err = connect_tunnel(&mut rw, &target, &proxy, &mut io::sink()).unwrap_err();
        match err {
            Error::BadResponse(msg) => assert!(msg.contains("407"), "got {msg:?}"),
            other => panic!("unexpected: {other:?}"),
        }
    }
}