use crate::{
Body, abi,
body::BodyInner,
callback::{
CompletionSignal, SignalCancelled, await_win32, borrow_context_ptr, leak_context_ptr,
reclaim_context_ptr,
},
error::{ContextError, Error},
proxy::{ProxyAction, ProxyConfig},
redirect::{self, Policy},
url::Url,
util::lock_or_clear,
};
use bytes::BytesMut;
use http::{StatusCode, Version};
use std::sync::{
Arc, Mutex,
atomic::{AtomicU32, Ordering},
};
use windows_sys::Win32::Networking::WinHttp::*;
/// Extension for tagging a failed `Result` with the URL of the request that produced it.
trait ResultUrlExt<T> {
    /// Passes `Ok` values through untouched; on `Err`, attaches a clone of `url` to the error.
    fn url_context(self, url: &Url) -> Result<T, Error>;
}
impl<T> ResultUrlExt<T> for Result<T, Error> {
    fn url_context(self, url: &Url) -> Result<T, Error> {
        match self {
            Ok(value) => Ok(value),
            Err(err) => Err(err.with_url(url.clone())),
        }
    }
}
/// A raw handle address stored as a plain integer, so it is `Send` and can be
/// captured by the `move` closures handed to `await_win32`.
#[derive(Clone, Copy)]
struct SendPtr(usize);
impl SendPtr {
    /// Reinterprets the stored address as an untyped mutable pointer.
    fn as_mut_ptr(self) -> *mut core::ffi::c_void {
        let SendPtr(address) = self;
        address as *mut core::ffi::c_void
    }
}
/// Owning wrapper around a raw WinHTTP handle; the handle is closed when the wrapper drops.
pub(crate) struct WinHttpHandle(pub *mut core::ffi::c_void);
impl WinHttpHandle {
    /// Copies the raw handle into a [`SendPtr`] so it can be moved into closures.
    fn as_send(&self) -> SendPtr {
        let raw = self.0;
        SendPtr(raw as usize)
    }
}
impl Drop for WinHttpHandle {
    fn drop(&mut self) {
        // Closing is delegated entirely to the ABI layer.
        abi::close_winhttp_handle(self.0);
    }
}
// SAFETY: NOTE(review) — these impls assert that a WinHTTP handle may be used and
// closed from any thread; the wrapper itself holds nothing but the raw pointer.
unsafe impl Send for WinHttpHandle {}
unsafe impl Sync for WinHttpHandle {}
#[derive(Debug)]
pub(crate) enum CallbackEvent {
Complete,
ReadComplete(u32),
WriteComplete(u32),
Win32Error(u32),
}
impl CallbackEvent {
fn unexpected(self, url: &Url) -> Error {
Error::request(format!("unexpected callback event: {self:?}")).with_url(url.clone())
}
pub fn into_result(self, state: &RequestState, url: &Url) -> Result<(), Error> {
match self {
CallbackEvent::Complete => Ok(()),
CallbackEvent::Win32Error(code) => Err(callback_error_to_error(code, state, url)),
other => Err(other.unexpected(url)),
}
}
pub fn into_read_complete(self, url: &Url) -> Result<u32, Error> {
match self {
CallbackEvent::ReadComplete(n) => Ok(n),
CallbackEvent::Win32Error(code) => Err(Error::from_win32(code).with_url(url.clone())),
other => Err(other.unexpected(url)),
}
}
pub fn into_write_complete(self, url: &Url) -> Result<u32, Error> {
match self {
CallbackEvent::WriteComplete(n) => Ok(n),
CallbackEvent::Win32Error(code) => Err(Error::from_win32(code).with_url(url.clone())),
other => Err(other.unexpected(url)),
}
}
}
/// State shared between an in-flight request future and `winhttp_callback`.
pub(crate) struct RequestState {
    /// Carries the next callback event to the awaiting future.
    pub signal: CompletionSignal<CallbackEvent>,
    /// Whether verbose status notifications are traced (read only with `tracing`).
    #[cfg_attr(not(feature = "tracing"), expect(dead_code))]
    pub verbose: bool,
    /// Flags from the last `SECURE_FAILURE` notification, used to enrich TLS errors.
    pub tls_failure_flags: AtomicU32,
    /// Buffer currently lent to `winhttp_read_data`.
    pub read_buffer: Mutex<Option<BytesMut>>,
    /// Request-body bytes currently lent to WinHTTP for sending.
    pub send_body: Mutex<Option<bytes::Bytes>>,
}
impl RequestState {
    /// Creates fresh state: new signal, zeroed TLS flags, no lent buffers.
    pub fn new(verbose: bool) -> Self {
        let signal = CompletionSignal::new();
        let tls_failure_flags = AtomicU32::new(0);
        let read_buffer = Mutex::new(None);
        let send_body = Mutex::new(None);
        Self {
            signal,
            verbose,
            tls_failure_flags,
            read_buffer,
            send_body,
        }
    }
}
// SAFETY: NOTE(review) — asserted so the state can be shared with the WinHTTP callback
// thread; confirm `CompletionSignal` is safe to signal from another thread.
unsafe impl Send for RequestState {}
unsafe impl Sync for RequestState {}
pub(crate) unsafe extern "system" fn winhttp_callback(
_hinternet: *mut core::ffi::c_void,
dw_context: usize,
dw_status: u32,
lpv_info: *mut std::ffi::c_void,
dw_info_length: u32,
) {
if dw_context == 0 {
return;
}
let state: &RequestState = unsafe { borrow_context_ptr(dw_context) };
match dw_status {
WINHTTP_CALLBACK_STATUS_SENDREQUEST_COMPLETE => {
state.signal.signal(CallbackEvent::Complete);
}
WINHTTP_CALLBACK_STATUS_HEADERS_AVAILABLE => {
state.signal.signal(CallbackEvent::Complete);
}
WINHTTP_CALLBACK_STATUS_READ_COMPLETE => {
state
.signal
.signal(CallbackEvent::ReadComplete(dw_info_length));
}
WINHTTP_CALLBACK_STATUS_WRITE_COMPLETE => {
let bytes = if !lpv_info.is_null() && dw_info_length >= 4 {
unsafe { *(lpv_info as *const u32) }
} else {
0
};
state.signal.signal(CallbackEvent::WriteComplete(bytes));
}
WINHTTP_CALLBACK_STATUS_REQUEST_ERROR => {
let result = unsafe { &*(lpv_info as *const WINHTTP_ASYNC_RESULT) };
state
.signal
.signal(CallbackEvent::Win32Error(result.dwError));
}
WINHTTP_CALLBACK_STATUS_SECURE_FAILURE => {
let flags = unsafe { *(lpv_info as *const u32) };
state.tls_failure_flags.store(flags, Ordering::Release);
}
WINHTTP_CALLBACK_STATUS_HANDLE_CLOSING => unsafe {
reclaim_context_ptr::<RequestState>(dw_context);
},
#[cfg(feature = "tracing")]
status => {
if state.verbose {
log_verbose_status(status, lpv_info, dw_info_length);
}
}
#[cfg(not(feature = "tracing"))]
_ => {}
}
}
/// Traces one verbose WinHTTP status notification, decoding its payload when it has one.
#[cfg(feature = "tracing")]
fn log_verbose_status(status: u32, info: *mut std::ffi::c_void, info_len: u32) {
    // Payload that carries a wide string (presumably UTF-16; decoding is done by the util helper).
    let wide_text = || unsafe { crate::util::wide_to_string_lossy(info, info_len) };
    // Payload that carries a DWORD count; reported as 0 when missing or too short.
    let dword = || {
        if info.is_null() || info_len < 4 {
            0
        } else {
            unsafe { *(info as *const u32) }
        }
    };
    match status {
        WINHTTP_CALLBACK_STATUS_RESOLVING_NAME => {
            let name = wide_text();
            trace!(name = %name, "WinHTTP: resolving name");
        }
        WINHTTP_CALLBACK_STATUS_NAME_RESOLVED => {
            let name = wide_text();
            trace!(name = %name, "WinHTTP: name resolved");
        }
        WINHTTP_CALLBACK_STATUS_CONNECTING_TO_SERVER => {
            let ip = wide_text();
            trace!(ip = %ip, "WinHTTP: connecting to server");
        }
        WINHTTP_CALLBACK_STATUS_CONNECTED_TO_SERVER => {
            let ip = wide_text();
            trace!(ip = %ip, "WinHTTP: connected to server");
        }
        WINHTTP_CALLBACK_STATUS_SENDING_REQUEST => {
            trace!("WinHTTP: sending request");
        }
        WINHTTP_CALLBACK_STATUS_REQUEST_SENT => {
            let bytes = dword();
            trace!(bytes = bytes, "WinHTTP: request sent");
        }
        WINHTTP_CALLBACK_STATUS_RECEIVING_RESPONSE => {
            trace!("WinHTTP: receiving response");
        }
        WINHTTP_CALLBACK_STATUS_RESPONSE_RECEIVED => {
            let bytes = dword();
            trace!(bytes = bytes, "WinHTTP: response received");
        }
        WINHTTP_CALLBACK_STATUS_REDIRECT => {
            let url = wide_text();
            trace!(url = %url, "WinHTTP: redirect");
        }
        _ => {}
    }
}
/// Settings consumed by `WinHttpSession::open` when creating a WinHTTP session.
pub(crate) struct SessionConfig {
/// User-Agent string passed to the session-open call.
pub user_agent: String,
/// Connect/send/read timeouts in milliseconds, forwarded to the set-timeouts call
/// (the tests in this file pass 0 for send/read).
pub connect_timeout_ms: u32,
pub send_timeout_ms: u32,
pub read_timeout_ms: u32,
/// Copied onto the session; enables verbose status tracing when the `tracing` feature is on.
pub verbose: bool,
/// When `Some`, applied as `WINHTTP_OPTION_MAX_CONNS_PER_SERVER`.
pub max_connections_per_host: Option<u32>,
/// Session-level proxy mode; selects the WinHTTP access type.
pub proxy: ProxyAction,
/// Redirect handling; `None` means up to 10 automatic redirects.
pub redirect_policy: Option<Policy>,
/// When `false`, HTTP/2 is enabled on a best-effort basis.
pub http1_only: bool,
}
/// An asynchronous WinHTTP session with `winhttp_callback` installed.
pub(crate) struct WinHttpSession {
    /// The session handle; closed on drop.
    pub handle: WinHttpHandle,
    /// Copied from [`SessionConfig::verbose`] into every request's state.
    pub verbose: bool,
}
impl WinHttpSession {
    /// Opens and configures a WinHTTP session from `config`.
    ///
    /// HTTP/2 enablement and gzip/deflate decompression are best-effort; their failures are
    /// ignored. Every other option failure is returned as an error.
    ///
    /// Timeouts above `i32::MAX` ms cannot be represented as the `int` WinHTTP takes. They are
    /// passed as `-1` ("wait indefinitely"), which is what `u32::MAX` already produced.
    /// Previously such values wrapped to other negative numbers, which WinHTTP rejects
    /// with ERROR_INVALID_PARAMETER.
    ///
    /// # Errors
    /// Returns the first failing ABI call's error.
    pub fn open(config: &SessionConfig) -> Result<Self, Error> {
        /// Converts a millisecond timeout to WinHTTP's signed form without wrapping.
        fn timeout_arg(ms: u32) -> i32 {
            i32::try_from(ms).unwrap_or(-1)
        }
        /// Automatic-redirect limit used when no redirect policy is configured.
        const DEFAULT_MAX_REDIRECTS: u32 = 10;

        let (access_type, proxy_str) = match &config.proxy {
            ProxyAction::Direct => (WINHTTP_ACCESS_TYPE_NO_PROXY, None),
            ProxyAction::Named(url, _) => (WINHTTP_ACCESS_TYPE_NAMED_PROXY, Some(url.as_str())),
            ProxyAction::Automatic => (WINHTTP_ACCESS_TYPE_AUTOMATIC_PROXY, None),
        };
        let h_session = abi::winhttp_open_session(
            &config.user_agent,
            access_type,
            proxy_str,
            WINHTTP_FLAG_ASYNC,
        )?;
        // Wrap immediately so the handle is closed if any later step fails.
        let session = WinHttpHandle(h_session);
        abi::winhttp_set_status_callback(
            session.0,
            Some(winhttp_callback),
            WINHTTP_CALLBACK_FLAG_ALL_NOTIFICATIONS,
        )?;
        abi::winhttp_set_timeouts(
            session.0,
            0,
            timeout_arg(config.connect_timeout_ms),
            timeout_arg(config.send_timeout_ms),
            timeout_arg(config.read_timeout_ms),
        )?;
        if !config.http1_only {
            // Best-effort: failures are deliberately ignored.
            let _ = abi::winhttp_set_option_u32(
                session.0,
                WINHTTP_OPTION_ENABLE_HTTP_PROTOCOL,
                WINHTTP_PROTOCOL_FLAG_HTTP2,
            );
        }
        // Best-effort: failures are deliberately ignored.
        let _ = abi::winhttp_set_option_u32(
            session.0,
            WINHTTP_OPTION_DECOMPRESSION,
            WINHTTP_DECOMPRESSION_FLAG_GZIP | WINHTTP_DECOMPRESSION_FLAG_DEFLATE,
        );
        abi::winhttp_set_option_u32(session.0, WINHTTP_OPTION_ASSURED_NON_BLOCKING_CALLBACKS, 1)?;
        if let Some(max_conns) = config.max_connections_per_host {
            abi::winhttp_set_option_u32(session.0, WINHTTP_OPTION_MAX_CONNS_PER_SERVER, max_conns)?;
        }
        match &config.redirect_policy {
            Some(policy) => match &policy.inner {
                redirect::PolicyInner::None => {
                    abi::winhttp_set_option_u32(
                        session.0,
                        WINHTTP_OPTION_REDIRECT_POLICY,
                        WINHTTP_OPTION_REDIRECT_POLICY_NEVER,
                    )?;
                }
                redirect::PolicyInner::Limited(max) => {
                    abi::winhttp_set_option_u32(
                        session.0,
                        WINHTTP_OPTION_MAX_HTTP_AUTOMATIC_REDIRECTS,
                        *max,
                    )?;
                }
            },
            None => {
                abi::winhttp_set_option_u32(
                    session.0,
                    WINHTTP_OPTION_MAX_HTTP_AUTOMATIC_REDIRECTS,
                    DEFAULT_MAX_REDIRECTS,
                )?;
            }
        }
        Ok(Self {
            handle: session,
            verbose: config.verbose,
        })
    }
}
/// What `execute_request` returns once response headers have arrived.
/// The body has not been read yet; NOTE(review): presumably callers pull it with
/// `read_chunk(&state, &request_handle, ..)` — confirm.
/// Fields drop in declaration order, so the request handle is closed before `state` is released.
pub(crate) struct RawResponse {
/// Open request handle; closed on drop.
pub request_handle: WinHttpHandle,
/// Per-request state shared with the WinHTTP callback.
pub state: Arc<RequestState>,
pub status: StatusCode,
pub version: Version,
/// Final URL as reported by `WINHTTP_OPTION_URL` (falls back to the request URL).
pub url: Url,
pub headers: http::HeaderMap,
}
/// Sends one HTTP request on `session` and waits until the response headers are available.
///
/// The request body, if any, is sent along one of three paths:
/// - a streaming body is framed by hand with chunked transfer encoding;
/// - a byte body of at most `LARGE_BODY_THRESHOLD` bytes is passed inline to the send call;
/// - a larger byte body gets an explicit `Content-Length` header and is written in
///   `LARGE_BODY_CHUNK_MAX`-sized pieces.
///
/// Every buffer handed to WinHTTP is first parked in `state.send_body`, so it lives as long as
/// the shared `RequestState`. That state is also leaked into the handle's context value and only
/// reclaimed by the callback on `HANDLE_CLOSING`.
///
/// # Errors
/// Returns WinHTTP failures, stream-body errors, and status-parsing errors, tagged with `url`.
pub(crate) async fn execute_request(
    session: &WinHttpSession,
    url: &Url,
    method: &str,
    headers: &[(String, String)],
    body: Option<Body>,
    proxy_config: &ProxyConfig,
    accept_invalid_certs: bool,
) -> Result<RawResponse, Error> {
    // Test builds use small limits so the multi-write path is reachable with modest bodies.
    #[cfg(not(test))]
    const LARGE_BODY_THRESHOLD: u64 = u32::MAX as u64;
    #[cfg(test)]
    const LARGE_BODY_THRESHOLD: u64 = 4 * 1024 * 1024;
    #[cfg(not(test))]
    const LARGE_BODY_CHUNK_MAX: usize = u32::MAX as usize;
    #[cfg(test)]
    const LARGE_BODY_CHUNK_MAX: usize = 2 * 1024 * 1024;

    let per_request_proxy = proxy_config.resolve(&url.host, url.is_https);
    trace!(
        url = %url,
        proxy = ?per_request_proxy,
        "proxy resolved for request",
    );
    let state = Arc::new(RequestState::new(session.verbose));
    let body_inner = body.map(|b| b.into_inner());
    // (address of parked bytes, length, has non-empty byte body, streaming body)
    let (body_ptr, body_len, has_bytes_body, mut stream) = match body_inner {
        Some(BodyInner::Bytes(v)) => {
            // Park the bytes in shared state so the address handed to WinHTTP stays valid.
            let mut guard = lock_or_clear(&state.send_body);
            let stored = guard.insert(v);
            if stored.is_empty() {
                (0usize, 0u64, false, None)
            } else {
                (stored.as_ptr() as usize, stored.len() as u64, true, None)
            }
        }
        Some(BodyInner::Stream(s)) => (0usize, 0u64, false, Some(s)),
        None => (0usize, 0u64, false, None),
    };

    // Locals drop in reverse declaration order, so `request_handle` closes before `_connect_handle`.
    let h_connect = abi::winhttp_connect(session.handle.0, &url.host, url.port).url_context(url)?;
    let _connect_handle = WinHttpHandle(h_connect);
    let h_request = abi::winhttp_open_request(h_connect, method, &url.path_and_query, url.is_https)
        .url_context(url)?;
    let request_handle = WinHttpHandle(h_request);

    let ctx = leak_context_ptr(&state);
    if let Err(e) =
        abi::winhttp_set_option_usize(request_handle.0, WINHTTP_OPTION_CONTEXT_VALUE, ctx)
    {
        // SAFETY: the context was never installed on the handle, so the callback can never
        // reclaim it; this is the only owner of the leaked pointer.
        unsafe {
            reclaim_context_ptr::<RequestState>(ctx);
        }
        return Err(e.with_url(url.clone()));
    }

    match &per_request_proxy {
        ProxyAction::Direct => {
            abi::winhttp_set_proxy_direct(request_handle.0).url_context(url)?;
        }
        ProxyAction::Named(proxy_url, proxy_creds) => {
            abi::winhttp_set_proxy_named(request_handle.0, proxy_url).url_context(url)?;
            if let Some((username, password)) = proxy_creds {
                abi::winhttp_set_proxy_credentials(request_handle.0, username, password)
                    .url_context(url)?;
            }
        }
        // Leave the session-level proxy configuration in effect.
        ProxyAction::Automatic => {}
    }
    if accept_invalid_certs && url.is_https {
        let security_flags: u32 = SECURITY_FLAG_IGNORE_UNKNOWN_CA
            | SECURITY_FLAG_IGNORE_CERT_DATE_INVALID
            | SECURITY_FLAG_IGNORE_CERT_CN_INVALID
            | SECURITY_FLAG_IGNORE_CERT_WRONG_USAGE;
        abi::winhttp_set_option_u32(
            request_handle.0,
            WINHTTP_OPTION_SECURITY_FLAGS,
            security_flags,
        )
        .url_context(url)?;
    }
    for (name, value) in headers {
        let header_line = format!("{name}: {value}\r\n");
        abi::winhttp_add_request_header(request_handle.0, &header_line).url_context(url)?;
    }

    if let Some(ref mut stream) = stream {
        use futures_util::StreamExt;
        trace!("body path: streaming (chunked transfer encoding)");
        abi::winhttp_add_request_header(request_handle.0, "Transfer-Encoding: chunked\r\n")
            .url_context(url)?;
        let h_send = request_handle.as_send();
        await_win32(&state.signal, move || {
            abi::winhttp_send_request(
                h_send.as_mut_ptr(),
                std::ptr::null(),
                0,
                WINHTTP_IGNORE_REQUEST_TOTAL_LENGTH,
            )
            .url_context(url)
        })
        .await?
        .into_result(&state, url)?;
        while let Some(chunk_result) = stream.next().await {
            let chunk = chunk_result.map_err(|e| {
                Error::request(ContextError::new("stream body error", e)).with_url(url.clone())
            })?;
            // A zero-length chunk would be read as the terminator, so skip it.
            if chunk.is_empty() {
                continue;
            }
            let header = format!("{:x}\r\n", chunk.len());
            let mut frame = Vec::with_capacity(header.len() + chunk.len() + 2);
            frame.extend_from_slice(header.as_bytes());
            frame.extend_from_slice(&chunk);
            frame.extend_from_slice(b"\r\n");
            // Park the frame in shared state; the previous frame's write has already completed.
            let (frame_ptr, frame_len) = {
                let mut guard = lock_or_clear(&state.send_body);
                let stored = guard.insert(frame.into());
                (stored.as_ptr() as usize, stored.len() as u32)
            };
            write_data(&state.signal, &request_handle, frame_ptr, frame_len, url).await?;
        }
        let terminator = b"0\r\n\r\n".to_vec();
        let (term_ptr, term_len) = {
            let mut guard = lock_or_clear(&state.send_body);
            let stored = guard.insert(terminator.into());
            (stored.as_ptr() as usize, stored.len() as u32)
        };
        write_data(&state.signal, &request_handle, term_ptr, term_len, url).await?;
    } else if body_len <= LARGE_BODY_THRESHOLD {
        trace!(body_len, "body path: inline");
        // Fits in u32 because LARGE_BODY_THRESHOLD <= u32::MAX.
        let inline_len = body_len as u32;
        let h_send = request_handle.as_send();
        await_win32(&state.signal, move || {
            let optional = if inline_len > 0 {
                body_ptr as *const std::ffi::c_void
            } else {
                std::ptr::null()
            };
            abi::winhttp_send_request(h_send.as_mut_ptr(), optional, inline_len, inline_len)
                .url_context(url)
        })
        .await?
        .into_result(&state, url)?;
    } else {
        trace!(body_len, "body path: large (multi-write)");
        abi::winhttp_add_request_header(
            request_handle.0,
            &format!("Content-Length: {body_len}\r\n"),
        )
        .url_context(url)?;
        let h_send = request_handle.as_send();
        await_win32(&state.signal, move || {
            abi::winhttp_send_request(
                h_send.as_mut_ptr(),
                std::ptr::null(),
                0,
                WINHTTP_IGNORE_REQUEST_TOTAL_LENGTH,
            )
            .url_context(url)
        })
        .await?
        .into_result(&state, url)?;
        if has_bytes_body {
            let total_len = body_len as usize;
            let mut offset: usize = 0;
            while offset < total_len {
                // Bounded by LARGE_BODY_CHUNK_MAX (<= u32::MAX), so the u32 cast is lossless.
                let chunk_size = (total_len - offset).min(LARGE_BODY_CHUNK_MAX);
                write_data(
                    &state.signal,
                    &request_handle,
                    body_ptr + offset,
                    chunk_size as u32,
                    url,
                )
                .await?;
                offset += chunk_size;
            }
        }
    }

    let h_recv = request_handle.as_send();
    await_win32(&state.signal, move || {
        abi::winhttp_receive_response(h_recv.as_mut_ptr()).url_context(url)
    })
    .await?
    .into_result(&state, url)?;
    // The body has been fully handed off, so release the parked bytes.
    let _ = lock_or_clear(&state.send_body).take();
    let status = query_status_code(request_handle.0, url)?;
    let version = query_version(request_handle.0);
    let headers = query_headers(request_handle.0, url)?;
    let final_url = abi::winhttp_query_option_url(request_handle.0, WINHTTP_OPTION_URL)
        .and_then(|s| Url::parse(&s).ok())
        .unwrap_or_else(|| url.clone());
    trace!(
        status = status.as_u16(),
        version = ?version,
        final_url = %final_url,
        header_count = headers.len(),
        "headers received",
    );
    Ok(RawResponse {
        request_handle,
        state,
        status,
        version,
        url: final_url,
        headers,
    })
}
/// Writes `data_len` bytes starting at address `data_ptr` on `handle` and waits for completion.
///
/// The caller must keep the bytes alive until this resolves (the callers in this file park them
/// in `RequestState::send_body`). Returns the byte count reported by WinHTTP.
async fn write_data(
    signal: &CompletionSignal<CallbackEvent>,
    handle: &WinHttpHandle,
    data_ptr: usize,
    data_len: u32,
    url: &Url,
) -> Result<u32, Error> {
    let target = handle.as_send();
    let event = await_win32(signal, move || {
        let source = data_ptr as *const std::ffi::c_void;
        abi::winhttp_write_data(target.as_mut_ptr(), source, data_len).url_context(url)
    })
    .await?;
    event.into_write_complete(url)
}
/// Reads up to 8 KiB of response body from `handle`.
///
/// Returns `Ok(None)` once WinHTTP reports a zero-byte read (end of body). The destination
/// buffer is parked in `state.read_buffer` for the duration of the read, so it stays alive
/// even if this future is dropped before completion.
pub(crate) async fn read_chunk(
    state: &Arc<RequestState>,
    handle: &WinHttpHandle,
    url: &Url,
) -> Result<Option<bytes::Bytes>, Error> {
    const READ_BUF_SIZE: usize = 8192;
    let fresh = BytesMut::with_capacity(READ_BUF_SIZE);
    let source = handle.as_send();
    let event = await_win32(&state.signal, move || {
        let (dst, dst_len) = {
            let mut slot = lock_or_clear(&state.read_buffer);
            let spare = slot.insert(fresh).spare_capacity_mut();
            (spare.as_mut_ptr().cast::<std::ffi::c_void>(), spare.len() as u32)
        };
        abi::winhttp_read_data(source.as_mut_ptr(), dst, dst_len).url_context(url)
    })
    .await?;
    let read = event.into_read_complete(url)?;

    let mut slot = lock_or_clear(&state.read_buffer);
    if read == 0 {
        slot.take();
        return Ok(None);
    }
    let mut buf = match slot.take() {
        Some(parked) => parked,
        None => {
            return Err(Error::request("read buffer missing after read (invariant violated)")
                .with_url(url.clone()));
        }
    };
    let filled = read as usize;
    if filled <= buf.capacity() {
        // SAFETY: WinHTTP wrote `filled` bytes into this (initially empty) buffer's spare
        // capacity, and `filled` was just checked against the capacity.
        unsafe {
            buf.set_len(filled);
        }
        Ok(Some(buf.freeze()))
    } else {
        Err(Error::request(format!(
            "WinHTTP reported {read} bytes read but buffer capacity is {} (invariant violated)",
            buf.capacity(),
        ))
        .with_url(url.clone()))
    }
}
/// Fetches the raw response header block and parses it into a `HeaderMap`.
fn query_headers(h_request: *mut core::ffi::c_void, url: &Url) -> Result<http::HeaderMap, Error> {
    abi::winhttp_query_raw_headers(h_request)
        .url_context(url)
        .map(|raw| parse_raw_headers(&raw))
}
/// Parses a raw header block into a `HeaderMap`, keeping repeated headers.
///
/// Blank lines, the `HTTP/...` status line, lines without a colon, and lines whose
/// name or value is rejected by `http` are skipped. Names and values are trimmed.
fn parse_raw_headers(raw: &str) -> http::HeaderMap {
    raw.lines()
        .filter(|line| !line.is_empty() && !line.starts_with("HTTP/"))
        .filter_map(|line| line.split_once(':'))
        .filter_map(|(name, value)| {
            let name = http::header::HeaderName::from_bytes(name.trim().as_bytes()).ok()?;
            let value = http::header::HeaderValue::from_bytes(value.trim().as_bytes()).ok()?;
            Some((name, value))
        })
        .fold(http::HeaderMap::new(), |mut map, (name, value)| {
            map.append(name, value);
            map
        })
}
fn query_status_code(h_request: *mut core::ffi::c_void, url: &Url) -> Result<StatusCode, Error> {
let status_code = abi::winhttp_query_header_u32(
h_request,
WINHTTP_QUERY_STATUS_CODE | WINHTTP_QUERY_FLAG_NUMBER,
)
.url_context(url)?;
StatusCode::from_u16(status_code as u16).map_err(|e| {
Error::request(ContextError::new(format!("invalid status code: {status_code}"), e))
.with_url(url.clone())
})
}
/// Determines the HTTP version used for the response on `h_request`.
///
/// Combines the negotiated-protocol option with the `VERSION` header; see `resolve_version`.
fn query_version(h_request: *mut core::ffi::c_void) -> Version {
    let negotiated = abi::winhttp_query_option_u32(h_request, WINHTTP_OPTION_HTTP_PROTOCOL_USED);
    let header = abi::winhttp_query_header_string(h_request, WINHTTP_QUERY_VERSION);
    resolve_version(negotiated, header.as_deref())
}
/// Chooses the response version.
///
/// The HTTP/3 protocol flag wins over HTTP/2. With neither flag, only an exact `"HTTP/1.0"`
/// version string yields HTTP/1.0; everything else defaults to HTTP/1.1.
fn resolve_version(protocol_flags: Option<u32>, version_str: Option<&str>) -> Version {
    let flags = protocol_flags.unwrap_or(0);
    if flags & WINHTTP_PROTOCOL_FLAG_HTTP3 != 0 {
        Version::HTTP_3
    } else if flags & WINHTTP_PROTOCOL_FLAG_HTTP2 != 0 {
        Version::HTTP_2
    } else if version_str == Some("HTTP/1.0") {
        Version::HTTP_10
    } else {
        Version::HTTP_11
    }
}
fn callback_error_to_error(code: u32, state: &RequestState, url: &Url) -> Error {
let mut err = Error::from_win32(code);
err.inner.url = Some(Box::new(url.clone()));
if code == ERROR_WINHTTP_SECURE_FAILURE {
let tls_flags = state.tls_failure_flags.load(Ordering::Acquire);
let detail = describe_tls_failure(tls_flags);
if let Some(source) = err.inner.source.take() {
err.inner.source =
Some(Box::new(ContextError::new(format!("TLS error: {detail}"), source)));
}
}
err
}
fn describe_tls_failure(flags: u32) -> String {
let mut parts = Vec::new();
if flags & WINHTTP_CALLBACK_STATUS_FLAG_CERT_REV_FAILED != 0 {
parts.push("revocation check failed");
}
if flags & WINHTTP_CALLBACK_STATUS_FLAG_INVALID_CERT != 0 {
parts.push("invalid certificate");
}
if flags & WINHTTP_CALLBACK_STATUS_FLAG_CERT_REVOKED != 0 {
parts.push("certificate revoked");
}
if flags & WINHTTP_CALLBACK_STATUS_FLAG_INVALID_CA != 0 {
parts.push("invalid CA");
}
if flags & WINHTTP_CALLBACK_STATUS_FLAG_CERT_CN_INVALID != 0 {
parts.push("certificate CN mismatch");
}
if flags & WINHTTP_CALLBACK_STATUS_FLAG_CERT_DATE_INVALID != 0 {
parts.push("certificate expired or not yet valid");
}
if flags & WINHTTP_CALLBACK_STATUS_FLAG_SECURITY_CHANNEL_ERROR != 0 {
parts.push("security channel error");
}
if parts.is_empty() {
"unknown TLS failure".to_owned()
} else {
parts.join(", ")
}
}
impl From<SignalCancelled> for Error {
fn from(sc: SignalCancelled) -> Self {
Error::request(sc)
}
}
// Unit and integration tests for the WinHTTP transport layer. The async tests talk to a
// local wiremock server through a real WinHTTP session.
#[cfg(test)]
mod tests {
use super::*;
// Under cfg(test), LARGE_BODY_THRESHOLD is 4 MiB and LARGE_BODY_CHUNK_MAX is 2 MiB, so a
// 5 MiB body goes through the Content-Length + multi-write path in several writes.
#[tokio::test]
async fn large_body_multi_write_path() {
use wiremock::matchers::{method, path};
use wiremock::{Mock, MockServer, ResponseTemplate};
let server = MockServer::start().await;
Mock::given(method("POST"))
.and(path("/large-ut"))
.respond_with(ResponseTemplate::new(200).set_body_string("ok"))
.expect(1)
.mount(&server)
.await;
let config = SessionConfig {
user_agent: String::new(),
connect_timeout_ms: 10_000,
send_timeout_ms: 0,
read_timeout_ms: 0,
verbose: false,
max_connections_per_host: None,
proxy: ProxyAction::Automatic,
redirect_policy: None,
http1_only: false,
};
let session = WinHttpSession::open(&config).expect("session should open");
let url: Url = format!("{}/large-ut", server.uri()).parse().unwrap();
let proxy_config = ProxyConfig::none();
let body = Body::from(vec![b'X'; 5 * 1024 * 1024]);
let raw = execute_request(&session, &url, "POST", &[], Some(body), &proxy_config, false)
.await
.expect("large body request should succeed");
assert_eq!(raw.status, 200);
}
// Table-driven check of session-level proxy and redirect settings. Each case uses a fresh
// mock server; for redirect cases the source path answers 302 with a Location on the same server.
#[tokio::test]
async fn session_config_variants() {
use wiremock::matchers::{method, path};
use wiremock::{Mock, MockServer, ResponseTemplate};
struct Case {
label: &'static str,
proxy: ProxyAction,
redirect_policy: Option<Policy>,
src_path: &'static str,
redirect_to: Option<&'static str>,
dst_path: Option<&'static str>,
expected_status: u16,
}
let cases = [
Case {
label: "ProxyAction::Direct",
proxy: ProxyAction::Direct,
redirect_policy: None,
src_path: "/direct-test",
redirect_to: None,
dst_path: None,
expected_status: 200,
},
// With redirects disabled, no destination mock is mounted and the 302 must surface.
Case {
label: "Policy::none() → 302 returned as-is",
proxy: ProxyAction::Automatic,
redirect_policy: Some(Policy::none()),
src_path: "/rp-src",
redirect_to: Some("/rp-dst"),
dst_path: None, expected_status: 302,
},
Case {
label: "Policy::limited(5) → redirect followed",
proxy: ProxyAction::Automatic,
redirect_policy: Some(Policy::limited(5)),
src_path: "/lim-src",
redirect_to: Some("/lim-dst"),
dst_path: Some("/lim-dst"),
expected_status: 200,
},
];
for case in cases {
let server = MockServer::start().await;
if let Some(redir) = case.redirect_to {
Mock::given(method("GET"))
.and(path(case.src_path))
.respond_with(
ResponseTemplate::new(302)
.insert_header("location", format!("{}{redir}", server.uri())),
)
.expect(1)
.mount(&server)
.await;
} else {
Mock::given(method("GET"))
.and(path(case.src_path))
.respond_with(ResponseTemplate::new(200).set_body_string("ok"))
.expect(1)
.mount(&server)
.await;
}
if let Some(dst) = case.dst_path {
Mock::given(method("GET"))
.and(path(dst))
.respond_with(ResponseTemplate::new(200).set_body_string("arrived"))
.expect(1)
.mount(&server)
.await;
}
let config = SessionConfig {
user_agent: String::new(),
connect_timeout_ms: 10_000,
send_timeout_ms: 0,
read_timeout_ms: 0,
verbose: false,
max_connections_per_host: None,
proxy: case.proxy,
redirect_policy: case.redirect_policy,
http1_only: false,
};
let session = WinHttpSession::open(&config)
.unwrap_or_else(|e| panic!("{}: session open failed: {e}", case.label));
let url: Url = format!("{}{}", server.uri(), case.src_path)
.parse()
.unwrap();
let proxy_config = ProxyConfig::none();
let raw = execute_request(&session, &url, "GET", &[], None, &proxy_config, false)
.await
.unwrap_or_else(|e| panic!("{}: request failed: {e}", case.label));
assert_eq!(raw.status, case.expected_status, "{}", case.label);
}
}
// The session is configured with a named proxy pointing at the mock server itself, while the
// per-request NO_PROXY entry for 127.0.0.1 is expected to resolve the request to a direct
// connection.
#[tokio::test]
async fn per_request_proxy_direct_via_no_proxy() {
use wiremock::matchers::{method, path};
use wiremock::{Mock, MockServer, ResponseTemplate};
let server = MockServer::start().await;
Mock::given(method("GET"))
.and(path("/np-direct"))
.respond_with(ResponseTemplate::new(200).set_body_string("bypassed"))
.expect(1)
.mount(&server)
.await;
let config = SessionConfig {
user_agent: String::new(),
connect_timeout_ms: 10_000,
send_timeout_ms: 0,
read_timeout_ms: 0,
verbose: false,
max_connections_per_host: None,
proxy: ProxyAction::Named(server.uri(), None),
redirect_policy: None,
http1_only: false,
};
let session = WinHttpSession::open(&config).expect("session should open");
let url: Url = format!("{}/np-direct", server.uri()).parse().unwrap();
let mut proxy_config = ProxyConfig::none();
crate::NoProxy::from_string("127.0.0.1")
.unwrap()
.apply_to(&mut proxy_config);
let raw = execute_request(&session, &url, "GET", &[], None, &proxy_config, false)
.await
.expect("direct bypass request should succeed");
assert_eq!(raw.status, 200);
}
// Every single flag and two combinations must mention each expected description.
#[test]
fn describe_tls_failure_table() {
let cases: &[(u32, &[&str])] = &[
(0, &["unknown TLS failure"]),
(WINHTTP_CALLBACK_STATUS_FLAG_CERT_REV_FAILED, &["revocation check failed"]),
(WINHTTP_CALLBACK_STATUS_FLAG_INVALID_CERT, &["invalid certificate"]),
(WINHTTP_CALLBACK_STATUS_FLAG_CERT_REVOKED, &["certificate revoked"]),
(WINHTTP_CALLBACK_STATUS_FLAG_INVALID_CA, &["invalid CA"]),
(WINHTTP_CALLBACK_STATUS_FLAG_CERT_CN_INVALID, &["certificate CN mismatch"]),
(
WINHTTP_CALLBACK_STATUS_FLAG_CERT_DATE_INVALID,
&["certificate expired or not yet valid"],
),
(WINHTTP_CALLBACK_STATUS_FLAG_SECURITY_CHANNEL_ERROR, &["security channel error"]),
(
WINHTTP_CALLBACK_STATUS_FLAG_CERT_REVOKED
| WINHTTP_CALLBACK_STATUS_FLAG_CERT_DATE_INVALID,
&["certificate revoked", "certificate expired or not yet valid"],
),
(
WINHTTP_CALLBACK_STATUS_FLAG_CERT_REV_FAILED
| WINHTTP_CALLBACK_STATUS_FLAG_INVALID_CERT
| WINHTTP_CALLBACK_STATUS_FLAG_CERT_REVOKED
| WINHTTP_CALLBACK_STATUS_FLAG_INVALID_CA
| WINHTTP_CALLBACK_STATUS_FLAG_CERT_CN_INVALID
| WINHTTP_CALLBACK_STATUS_FLAG_CERT_DATE_INVALID
| WINHTTP_CALLBACK_STATUS_FLAG_SECURITY_CHANNEL_ERROR,
&[
"revocation check failed",
"invalid certificate",
"certificate revoked",
"invalid CA",
"certificate CN mismatch",
"certificate expired or not yet valid",
"security channel error",
],
),
];
for &(flags, expected) in cases {
let s = describe_tls_failure(flags);
for needle in expected {
assert!(s.contains(needle), "flags 0x{flags:X}: expected {needle:?}, got: {s}");
}
}
}
// `into_result` accepts only `Complete`; Win32 errors map to their error kind, and
// read/write completions are rejected as request errors.
#[test]
fn callback_event_into_result() {
let url: Url = "https://example.com".parse().unwrap();
let state = RequestState::new(false);
type TestCase = (CallbackEvent, Result<(), fn(&Error) -> bool>);
let cases: Vec<TestCase> = vec![
(CallbackEvent::Complete, Ok(())),
(CallbackEvent::Win32Error(ERROR_WINHTTP_TIMEOUT), Err(Error::is_timeout)),
(CallbackEvent::ReadComplete(42), Err(Error::is_request)),
(CallbackEvent::WriteComplete(0), Err(Error::is_request)),
];
for (event, expected) in cases {
let label = format!("{event:?}");
let result = event.into_result(&state, &url);
match expected {
Ok(()) => assert!(result.is_ok(), "{label}: expected Ok"),
Err(check) => {
let err = result.expect_err(&format!("{label}: expected Err"));
assert!(check(&err), "{label}: wrong error kind: {err}");
}
}
}
}
// Each converter returns its byte count for the matching event, a request error for the other
// completion kind, and a timeout error for ERROR_WINHTTP_TIMEOUT.
#[test]
fn callback_event_into_read_write_complete() {
let url: Url = "https://example.com".parse().unwrap();
type TestCase<'a> = (
&'a str,
fn(CallbackEvent, &Url) -> crate::Result<u32>,
CallbackEvent,
u32,
CallbackEvent,
);
let cases: Vec<TestCase<'_>> = vec![
(
"into_read_complete",
|e, u| e.into_read_complete(u),
CallbackEvent::ReadComplete(512),
512,
CallbackEvent::WriteComplete(0),
),
(
"into_write_complete",
|e, u| e.into_write_complete(u),
CallbackEvent::WriteComplete(256),
256,
CallbackEvent::ReadComplete(0),
),
];
for (label, method, happy_event, expected_val, wrong_event) in cases {
assert_eq!(method(happy_event, &url).unwrap(), expected_val, "{label}: happy");
let err = method(wrong_event, &url).unwrap_err();
assert!(err.is_request(), "{label}: wrong variant should be request error");
let err = method(CallbackEvent::Win32Error(ERROR_WINHTTP_TIMEOUT), &url).unwrap_err();
assert!(err.is_timeout(), "{label}: timeout variant");
}
}
// Pins the Display text and that the cancellation is kept as the error source.
#[test]
fn signal_cancelled_into_error() {
let err: Error = SignalCancelled.into();
assert!(err.is_request());
assert_eq!(err.to_string(), "error sending request");
let source = std::error::Error::source(&err).expect("should have source");
assert!(source.to_string().contains("cancelled"));
}
#[test]
fn callback_error_to_error_preserves_url() {
let url: Url = "https://example.com/test".parse().unwrap();
let state = RequestState::new(false);
let err = callback_error_to_error(ERROR_WINHTTP_TIMEOUT, &state, &url);
assert!(err.is_timeout());
assert_eq!(err.url().map(|u| u.as_str()), Some("https://example.com/test"));
}
// The TLS detail must appear in the Debug output (via the wrapped source) but not in Display.
#[test]
fn tls_failure_enrichment() {
let url: Url = "https://example.com".parse().unwrap();
let state = RequestState::new(false);
state
.tls_failure_flags
.store(WINHTTP_CALLBACK_STATUS_FLAG_INVALID_CA, std::sync::atomic::Ordering::Release);
let err = callback_error_to_error(ERROR_WINHTTP_SECURE_FAILURE, &state, &url);
assert!(err.is_connect());
assert_eq!(err.to_string(), "error trying to connect for url (https://example.com/)");
let debug = format!("{err:?}");
assert!(
debug.contains("invalid CA"),
"TLS error should be enriched with failure details in debug, got: {debug}"
);
}
// (raw header block, expected (name, value) pairs, label)
#[test]
fn parse_raw_headers_table() {
type TestCase<'a> = (&'a str, &'a [(&'a str, &'a str)], &'a str);
let cases: &[TestCase] = &[
("", &[], "empty input"),
("HTTP/1.1 200 OK\r\n", &[], "status line only"),
(
"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\nContent-Length: 42\r\n",
&[("content-type", "text/html"), ("content-length", "42")],
"typical response",
),
(
"HTTP/1.1 200 OK\r\nLocation: https://example.com:8080/path\r\n",
&[("location", "https://example.com:8080/path")],
"colon in value",
),
(
"HTTP/1.1 200 OK\r\nmalformed-line-without-colon\r\nContent-Type: text/plain\r\n",
&[("content-type", "text/plain")],
"no-colon line skipped",
),
(
"HTTP/1.1 200 OK\r\n X-Custom : value with spaces \r\n",
&[("x-custom", "value with spaces")],
"whitespace trimmed",
),
];
for &(raw, expected, label) in cases {
let headers = parse_raw_headers(raw);
assert_eq!(headers.len(), expected.len(), "{label}: header count");
for &(name, value) in expected {
assert_eq!(
headers
.get(name)
.unwrap_or_else(|| panic!("{label}: missing {name}")),
value,
"{label}: {name}"
);
}
}
}
// Repeated headers must be appended, not overwritten.
#[test]
fn parse_raw_headers_duplicate_headers() {
let raw = "HTTP/1.1 200 OK\r\nSet-Cookie: a=1\r\nSet-Cookie: b=2\r\n";
let headers = parse_raw_headers(raw);
let cookies: Vec<&str> = headers
.get_all("set-cookie")
.iter()
.map(|v| v.to_str().unwrap())
.collect();
assert_eq!(cookies.len(), 2);
assert!(cookies.contains(&"a=1"));
assert!(cookies.contains(&"b=2"));
}
// (protocol flags, version string, expected version, label)
#[test]
fn resolve_version_table() {
let cases: &[(Option<u32>, Option<&str>, Version, &str)] = &[
(None, None, Version::HTTP_11, "no info defaults to HTTP/1.1"),
(None, Some("HTTP/1.0"), Version::HTTP_10, "version string HTTP/1.0"),
(None, Some("HTTP/1.1"), Version::HTTP_11, "version string HTTP/1.1"),
(None, Some("HTTP/2.0"), Version::HTTP_11, "unrecognized version string defaults"),
(Some(0), None, Version::HTTP_11, "flags zero defaults to HTTP/1.1"),
(Some(0), Some("HTTP/1.0"), Version::HTTP_10, "flags zero falls through to string"),
(Some(WINHTTP_PROTOCOL_FLAG_HTTP2), None, Version::HTTP_2, "HTTP/2 flag"),
(
Some(WINHTTP_PROTOCOL_FLAG_HTTP2),
Some("HTTP/1.1"),
Version::HTTP_2,
"HTTP/2 flag takes precedence over string",
),
(Some(WINHTTP_PROTOCOL_FLAG_HTTP3), None, Version::HTTP_3, "HTTP/3 flag"),
(
Some(WINHTTP_PROTOCOL_FLAG_HTTP3 | WINHTTP_PROTOCOL_FLAG_HTTP2),
None,
Version::HTTP_3,
"HTTP/3 takes precedence over HTTP/2",
),
];
for &(flags, version_str, expected, label) in cases {
let result = resolve_version(flags, version_str);
assert_eq!(result, expected, "resolve_version: {label}");
}
}
}