use std::collections::HashMap;
use std::fs::File;
use std::io::{ErrorKind, Read, Write};
use std::net::{IpAddr, Ipv4Addr, SocketAddr, TcpStream, ToSocketAddrs};
use std::path::Path;
use std::sync::{Arc, Mutex, OnceLock};
use std::time::{Duration, Instant};
/// Reports whether the SSRF guard has been switched off through the
/// `SUPERMACHINE_BUILD_ALLOW_PRIVATE_FETCH` environment variable.
///
/// Only the exact value `1` or any ASCII-case variant of `true` enables the
/// override; an unset, unreadable, or differently-valued variable yields
/// `false`.
pub(crate) fn allow_private_fetch() -> bool {
    match std::env::var("SUPERMACHINE_BUILD_ALLOW_PRIVATE_FETCH") {
        Ok(value) => value == "1" || value.eq_ignore_ascii_case("true"),
        Err(_) => false,
    }
}
/// Returns `true` when `ip` must not be contacted by the fetcher.
///
/// IPv4: loopback, RFC 1918 private, link-local, unspecified, broadcast,
/// anything in `0.0.0.0/8`, and `100.64.0.0/10`.
/// IPv6: loopback, unspecified, `fe80::/10`, `fc00::/7`; an IPv4-mapped
/// address (`::ffff:a.b.c.d`) is judged by the IPv4 rules for `a.b.c.d`.
pub(crate) fn is_blocked_ip(ip: &IpAddr) -> bool {
    match ip {
        IpAddr::V4(addr) => {
            let [first, second, ..] = addr.octets();
            let zero_net = first == 0;
            let shared_space = first == 100 && (64..128).contains(&second);
            addr.is_loopback()
                || addr.is_private()
                || addr.is_link_local()
                || addr.is_unspecified()
                || addr.is_broadcast()
                || zero_net
                || shared_space
        }
        IpAddr::V6(addr) => {
            let seg = addr.segments();
            let v4_mapped = seg[..5].iter().all(|&word| word == 0) && seg[5] == 0xffff;
            if v4_mapped {
                // The low 32 bits carry the embedded IPv4 address.
                let high = |word: u16| (word >> 8) as u8;
                let low = |word: u16| (word & 0xff) as u8;
                let embedded = Ipv4Addr::new(high(seg[6]), low(seg[6]), high(seg[7]), low(seg[7]));
                return is_blocked_ip(&IpAddr::V4(embedded));
            }
            let link_local = (seg[0] & 0xffc0) == 0xfe80;
            let unique_local = (seg[0] & 0xfe00) == 0xfc00;
            addr.is_loopback() || addr.is_unspecified() || link_local || unique_local
        }
    }
}
/// Resolves `host` and refuses the result if any address is non-public.
///
/// `host` may be a DNS name, an IP literal, or a bracketed IPv6 literal as it
/// appears in a URL authority (`[::1]`); the brackets are removed before the
/// lookup because the resolver only understands the bare address form.
///
/// Returns every resolved address so the caller can connect to exactly the
/// addresses that were vetted here.
///
/// # Errors
/// - resolution fails or yields no address;
/// - any resolved address is rejected by [`is_blocked_ip`] while
///   [`allow_private_fetch`] is `false`.
pub(crate) fn guard_resolve(host: &str, port: u16) -> Result<Vec<IpAddr>, String> {
    // URL syntax wraps IPv6 literals in brackets; `to_socket_addrs` neither
    // parses nor resolves the bracketed form.
    let lookup_host = host
        .strip_prefix('[')
        .and_then(|inner| inner.strip_suffix(']'))
        .unwrap_or(host);
    let ips: Vec<IpAddr> = (lookup_host, port)
        .to_socket_addrs()
        .map_err(|e| format!("resolve {host}: {e}"))?
        .map(|sa| sa.ip())
        .collect();
    if ips.is_empty() {
        return Err(format!("cannot resolve host {host}"));
    }
    if !allow_private_fetch() {
        // One bad record is enough to reject the whole host.
        if let Some(bad) = ips.iter().find(|ip| is_blocked_ip(ip)) {
            return Err(format!(
                "refusing to connect to non-public address {bad} (host {host}) — \
                 set SUPERMACHINE_BUILD_ALLOW_PRIVATE_FETCH=1 to override"
            ));
        }
    }
    Ok(ips)
}
/// Splits an absolute `http`/`https` URL into `(is_https, host, port, request_target)`.
///
/// - The scheme is matched case-insensitively; anything else is an error.
/// - The authority ends at the first `/` or `?` (or at a `#`), so a URL such
///   as `http://h?x=1` yields host `h` and target `/?x=1`.
/// - Any fragment is dropped: it is never part of an HTTP request target.
/// - Userinfo (`user:pw@`) is discarded.
/// - A bracketed IPv6 literal keeps its brackets in the returned host.
/// - A missing or empty port means the scheme default (443 / 80).
///
/// # Errors
/// Missing `://`, unsupported scheme, empty host, malformed IPv6 literal, or
/// a port that is not a valid `u16`.
pub(crate) fn parse_http_url(url: &str) -> Result<(bool, String, u16, String), String> {
    let (scheme, rest) = url
        .split_once("://")
        .ok_or_else(|| format!("not an absolute URL: {url}"))?;
    let https = match scheme.to_ascii_lowercase().as_str() {
        "https" => true,
        "http" => false,
        other => return Err(format!("unsupported scheme `{other}` in {url}")),
    };
    let default_port = if https { 443 } else { 80 };

    // The fragment is client-side only; cut it before looking for the path.
    let rest = rest.split('#').next().unwrap_or(rest);
    let (authority, path) = match rest.find(|c: char| c == '/' || c == '?') {
        // Query directly after the authority: the path is implicitly "/".
        Some(i) if rest[i..].starts_with('?') => (&rest[..i], format!("/{}", &rest[i..])),
        Some(i) => (&rest[..i], rest[i..].to_owned()),
        None => (rest, "/".to_owned()),
    };
    let authority = authority.rsplit('@').next().unwrap_or(authority);

    let (host, port_text) = if authority.starts_with('[') {
        // IPv6 literal: colons inside the brackets are not port separators.
        let close = authority
            .find(']')
            .ok_or_else(|| format!("unterminated IPv6 literal in {url}"))?;
        let (literal, tail) = authority.split_at(close + 1);
        match tail.strip_prefix(':') {
            Some(p) => (literal, Some(p)),
            None if tail.is_empty() => (literal, None),
            None => return Err(format!("invalid authority in {url}")),
        }
    } else {
        match authority.rsplit_once(':') {
            Some((h, p)) => (h, Some(p)),
            None => (authority, None),
        }
    };
    let port = match port_text {
        None | Some("") => default_port,
        Some(p) => p
            .parse::<u16>()
            .map_err(|_| format!("invalid port `{p}` in {url}"))?,
    };
    if host.is_empty() {
        return Err(format!("empty host in {url}"));
    }
    Ok((https, host.to_owned(), port, path))
}
/// Upper bound on establishing the TCP connection to one address.
const CONNECT_TIMEOUT: Duration = Duration::from_secs(30);
/// Read and write timeout applied to every connected socket.
const IO_TIMEOUT: Duration = Duration::from_secs(300);
/// Response header bytes tolerated before the response is rejected (256 KiB).
const MAX_HEADERS: usize = 256 * 1024;
/// Largest body that may be buffered in memory when no output file is used (2 GiB).
const MAX_MEM_BODY: usize = 2 * 1024 * 1024 * 1024;
/// A fully received HTTP response.
pub(crate) struct HttpResponse {
    /// Numeric status code taken from the status line.
    pub status: u16,
    /// Status line and header lines as received (CRLF-separated, lossily
    /// decoded as UTF-8), without the blank line that ends the header block.
    pub headers: String,
    /// Body bytes; empty when the body was streamed to a file instead.
    pub body: Vec<u8>,
}

/// Debug output shows the body length rather than the body bytes, so large
/// payloads do not end up in logs or assertion messages.
impl std::fmt::Debug for HttpResponse {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let body_len = self.body.len();
        let mut out = f.debug_struct("HttpResponse");
        out.field("status", &self.status);
        out.field("headers", &self.headers);
        out.field("body_len", &body_len);
        out.finish()
    }
}
/// Object-safe alias for "a bidirectional byte stream that can move between
/// threads"; lets plain TCP and TLS connections share one boxed type.
trait Stream: Read + Write + Send {}
// Blanket impl: every `Read + Write + Send` type is a `Stream`.
impl<T: Read + Write + Send> Stream for T {}
/// An idle keep-alive connection parked in the pool.
struct PooledConn {
/// The open connection (plain TCP or TLS).
stream: Box<dyn Stream>,
/// When the connection was returned to the pool; compared against
/// `KEEPALIVE_IDLE_TTL` on checkout.
idle_since: Instant,
}
/// Returns the process-wide keep-alive pool, creating it empty on first use.
///
/// Idle connections are grouped under a `(host, port, is_https)` key.
fn conn_pool() -> &'static Mutex<HashMap<(String, u16, bool), Vec<PooledConn>>> {
    static POOL: OnceLock<Mutex<HashMap<(String, u16, bool), Vec<PooledConn>>>> =
        OnceLock::new();
    POOL.get_or_init(Mutex::default)
}
/// Pooled connections idle for this long or longer are discarded on checkout.
const KEEPALIVE_IDLE_TTL: Duration = Duration::from_secs(20);
/// Maximum number of idle connections kept per `(host, port, is_https)` key.
const KEEPALIVE_MAX_PER_HOST: usize = 8;
/// Takes the most recently parked connection for `key` that is still within
/// `KEEPALIVE_IDLE_TTL`.
///
/// Expired connections encountered on the way are dropped. Returns `None`
/// when nothing usable is parked or the pool lock is poisoned.
fn pool_checkout(key: &(String, u16, bool)) -> Option<Box<dyn Stream>> {
    let mut guard = conn_pool().lock().ok()?;
    let idle = guard.get_mut(key)?;
    loop {
        // `pop` returns the newest entry; `?` ends the search once the list is empty.
        let candidate = idle.pop()?;
        if candidate.idle_since.elapsed() < KEEPALIVE_IDLE_TTL {
            return Some(candidate.stream);
        }
    }
}
/// Parks `stream` under `key` for later reuse.
///
/// The connection is simply dropped (closed) when the key already holds
/// `KEEPALIVE_MAX_PER_HOST` idle connections or the pool lock is poisoned.
fn pool_checkin(key: &(String, u16, bool), stream: Box<dyn Stream>) {
    let Ok(mut guard) = conn_pool().lock() else {
        return;
    };
    let idle = guard.entry(key.clone()).or_default();
    if idle.len() >= KEEPALIVE_MAX_PER_HOST {
        return;
    }
    idle.push(PooledConn {
        stream,
        idle_since: Instant::now(),
    });
}
/// Opens a connection to `ip:port`, wrapping it in TLS when `https` is set.
///
/// `ip` is dialed directly (no second DNS lookup); `host` is used only for
/// error messages and as the TLS server name. TLS uses the bundled
/// `webpki_roots` trust anchors and no client certificate.
///
/// # Errors
/// TCP connect failure or timeout, TLS configuration failure, or a `host`
/// that is not a valid TLS server name.
fn dial(https: bool, host: &str, port: u16, ip: IpAddr) -> Result<Box<dyn Stream>, String> {
    let tcp = TcpStream::connect_timeout(&SocketAddr::new(ip, port), CONNECT_TIMEOUT)
        .map_err(|e| format!("connect {host}:{port}: {e}"))?;
    // Timeouts are best-effort: failing to set them does not abort the dial.
    let _ = tcp.set_read_timeout(Some(IO_TIMEOUT));
    let _ = tcp.set_write_timeout(Some(IO_TIMEOUT));

    if !https {
        return Ok(Box::new(tcp));
    }

    let mut roots = rustls::RootCertStore::empty();
    roots.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
    let provider = Arc::new(rustls::crypto::ring::default_provider());
    let tls_config = rustls::ClientConfig::builder_with_provider(provider)
        .with_safe_default_protocol_versions()
        .map_err(|e| format!("TLS config: {e}"))?
        .with_root_certificates(roots)
        .with_no_client_auth();
    let server_name = rustls::pki_types::ServerName::try_from(host.to_owned())
        .map_err(|e| format!("bad server name {host}: {e}"))?;
    let session = rustls::ClientConnection::new(Arc::new(tls_config), server_name)
        .map_err(|e| format!("TLS connect {host}: {e}"))?;
    Ok(Box::new(rustls::StreamOwned::new(session, tcp)))
}
/// Sends one `GET` request over `stream` and reads the response.
///
/// `host` is written verbatim into the `Host` header and `path` into the
/// request line; each `extra_headers` pair becomes one `name: value` line.
/// The body goes to `output` when given, otherwise into memory.
///
/// Returns the response plus a flag saying whether the connection may be
/// reused for another request.
fn do_request(
    stream: &mut dyn Stream,
    host: &str,
    path: &str,
    extra_headers: &[(&str, &str)],
    output: Option<&Path>,
) -> Result<(HttpResponse, bool), String> {
    let mut request = format!(
        "GET {path} HTTP/1.1\r\nHost: {host}\r\nUser-Agent: supermachine\r\nConnection: keep-alive\r\n"
    );
    for &(name, value) in extra_headers {
        request.push_str(name);
        request.push_str(": ");
        request.push_str(value);
        request.push_str("\r\n");
    }
    // Blank line terminates the header block.
    request.push_str("\r\n");

    stream.write_all(request.as_bytes()).map_err(io_str)?;
    stream.flush().map_err(io_str)?;
    read_response(stream, output)
}
/// Performs a single `GET` of `url` (redirects are not followed).
///
/// A pooled keep-alive connection is tried first; if that attempt fails for
/// any reason the request is retried once on a fresh connection. For a fresh
/// connection the host is resolved through [`guard_resolve`] and every vetted
/// address is dialed in order until one connects, so one unreachable address
/// (for example an IPv6 record on an IPv4-only network) does not fail the
/// request. In the worst case this costs one connect timeout per address.
///
/// The `Host` header carries the port whenever it differs from the scheme
/// default, as HTTP requires.
///
/// The body is written to `output` when given (and `HttpResponse::body` is
/// then empty), otherwise it is returned in memory.
///
/// # Errors
/// URL parsing, resolution/guard rejection, the last dial error when no
/// address connects, or any request/response I/O or framing error.
pub(crate) fn http_get_once(
    url: &str,
    extra_headers: &[(&str, &str)],
    output: Option<&Path>,
) -> Result<HttpResponse, String> {
    let (https, host, port, path) = parse_http_url(url)?;
    let key = (host.clone(), port, https);

    // Servers use `Host` to build absolute URLs (redirects, auth realms), so a
    // non-default port must be part of it.
    let default_port = if https { 443 } else { 80 };
    let host_header = if port == default_port {
        host.clone()
    } else {
        format!("{host}:{port}")
    };

    if let Some(mut stream) = pool_checkout(&key) {
        // A pooled connection may have been closed by the peer; any failure
        // here silently falls through to a fresh dial.
        if let Ok((resp, reusable)) =
            do_request(&mut *stream, &host_header, &path, extra_headers, output)
        {
            if reusable {
                pool_checkin(&key, stream);
            }
            return Ok(resp);
        }
    }

    let ips = guard_resolve(&host, port)?;
    let mut dial_err = format!("cannot resolve host {host}");
    let mut connected = None;
    for ip in ips {
        match dial(https, &host, port, ip) {
            Ok(stream) => {
                connected = Some(stream);
                break;
            }
            Err(e) => dial_err = e,
        }
    }
    let Some(mut stream) = connected else {
        return Err(dial_err);
    };

    let (resp, reusable) = do_request(&mut *stream, &host_header, &path, extra_headers, output)?;
    if reusable {
        pool_checkin(&key, stream);
    }
    Ok(resp)
}
/// Converts an I/O error into this module's string error form (`io: <message>`).
fn io_str(e: std::io::Error) -> String {
    let mut message = String::from("io: ");
    message.push_str(&e.to_string());
    message
}
/// Returns `true` for error kinds that mean "the peer went away":
/// unexpected EOF, connection aborted, or connection reset.
fn is_eofish(e: &std::io::Error) -> bool {
    let disconnect_kinds = [
        ErrorKind::UnexpectedEof,
        ErrorKind::ConnectionAborted,
        ErrorKind::ConnectionReset,
    ];
    disconnect_kinds.contains(&e.kind())
}
/// Reads one HTTP/1.x response from `r`.
///
/// The body is written to `output` when given, otherwise buffered in memory
/// (bounded by `MAX_MEM_BODY`). Note that the output file is created as soon
/// as the headers are parsed, so it exists (possibly partial) even when an
/// error is returned afterwards.
///
/// Body framing is chosen in this order:
/// 1. status 1xx / 204 / 304: no body;
/// 2. `Transfer-Encoding: chunked`;
/// 3. `Content-Length: n`: exactly `n` bytes;
/// 4. otherwise: everything until the connection closes.
///
/// Returns the response and a `reusable` flag. The flag is `true` only when
/// the body had a definite end that was fully consumed without chunking
/// (cases 1 and 3) and the connection headers allow keep-alive (HTTP/1.1
/// without `Connection: close`, or an explicit `Connection: keep-alive`).
///
/// # Errors
/// - headers larger than `MAX_HEADERS`, or the connection closing before the
///   header block is complete;
/// - a missing or unparsable status line;
/// - a `Content-Length` or chunked body that ends early — a short body is an
///   error, never a silently truncated success;
/// - any other I/O error, or the in-memory body limit being exceeded.
fn read_response(
    r: &mut (impl Read + ?Sized),
    output: Option<&Path>,
) -> Result<(HttpResponse, bool), String> {
    let mut buf: Vec<u8> = Vec::with_capacity(8192);
    let mut tmp = [0u8; 32 * 1024];

    // Accumulate bytes until the blank line that ends the header block.
    let header_end = loop {
        if let Some(p) = find_sub(&buf, b"\r\n\r\n") {
            break p + 4;
        }
        if buf.len() > MAX_HEADERS {
            return Err("response headers exceed 256 KiB".to_owned());
        }
        match r.read(&mut tmp) {
            Ok(0) => return Err("connection closed before headers".to_owned()),
            Ok(n) => buf.extend_from_slice(&tmp[..n]),
            Err(e) if is_eofish(&e) => return Err("connection closed before headers".to_owned()),
            Err(e) => return Err(io_str(e)),
        }
    };

    // Header text excludes the terminating CRLFCRLF.
    let headers = String::from_utf8_lossy(&buf[..header_end.saturating_sub(4)]).into_owned();
    let mut lines = headers.split("\r\n");
    let status_line = lines.next().ok_or("missing status line")?;
    let http_1_1 = status_line.starts_with("HTTP/1.1");
    let status: u16 = status_line
        .split_whitespace()
        .nth(1)
        .and_then(|s| s.parse().ok())
        .ok_or("bad status line")?;

    let mut chunked = false;
    let mut content_length: Option<usize> = None;
    let mut conn_close = false;
    let mut conn_keep_alive = false;
    for line in lines {
        if let Some((k, v)) = line.split_once(':') {
            let k = k.trim().to_ascii_lowercase();
            let v = v.trim();
            if k == "transfer-encoding" && v.eq_ignore_ascii_case("chunked") {
                chunked = true;
            } else if k == "content-length" {
                content_length = v.parse().ok();
            } else if k == "connection" {
                let lv = v.to_ascii_lowercase();
                conn_close |= lv.contains("close");
                conn_keep_alive |= lv.contains("keep-alive");
            }
        }
    }

    let no_body = matches!(status, 100..=199 | 204 | 304);
    // Body bytes that arrived in the same reads as the headers.
    let leftover = buf[header_end..].to_vec();
    let mut sink = BodySink::new(output)?;

    // `definite`: the body end is known and was reached with nothing unread
    // left on the wire, so the connection is positioned at the next response.
    let definite = if no_body {
        true
    } else if chunked {
        read_chunked(r, &leftover, &mut sink)?;
        // Trailers after the last chunk are not consumed, so never reuse.
        false
    } else if let Some(len) = content_length {
        // Bytes beyond `len` are not part of this response and are ignored.
        sink.write(&leftover[..leftover.len().min(len)])?;
        let mut remaining = len.saturating_sub(leftover.len());
        while remaining > 0 {
            match r.read(&mut tmp) {
                Ok(0) => break,
                Ok(n) => {
                    let take = n.min(remaining);
                    sink.write(&tmp[..take])?;
                    remaining -= take;
                }
                Err(e) if is_eofish(&e) => break,
                Err(e) => return Err(io_str(e)),
            }
        }
        if remaining > 0 {
            // The peer promised `len` bytes and hung up early; handing back a
            // short body as success would silently corrupt the download.
            return Err(format!(
                "connection closed with {remaining} of {len} body bytes outstanding"
            ));
        }
        true
    } else {
        // No framing information: the body runs until the peer closes.
        sink.write(&leftover)?;
        loop {
            match r.read(&mut tmp) {
                Ok(0) => break,
                Ok(n) => sink.write(&tmp[..n])?,
                Err(e) if is_eofish(&e) => break,
                Err(e) => return Err(io_str(e)),
            }
        }
        false
    };

    // HTTP/1.1 is keep-alive unless told otherwise; older versions must opt in.
    let keepalive_ok = if http_1_1 {
        !conn_close
    } else {
        conn_keep_alive
    };
    let reusable = definite && keepalive_ok;
    Ok((
        HttpResponse {
            status,
            headers,
            body: sink.into_body(),
        },
        reusable,
    ))
}
/// Destination for response body bytes.
enum BodySink {
/// Body is written straight to an open file.
File(File),
/// Body is accumulated in memory.
Mem(Vec<u8>),
}
impl BodySink {
    /// Creates a file sink at `output` (creating/truncating the file), or an
    /// in-memory sink when `output` is `None`.
    fn new(output: Option<&Path>) -> Result<Self, String> {
        let Some(path) = output else {
            return Ok(BodySink::Mem(Vec::new()));
        };
        let file = File::create(path).map_err(|e| format!("create {}: {e}", path.display()))?;
        Ok(BodySink::File(file))
    }

    /// Appends `data` to the sink.
    ///
    /// The in-memory variant refuses to grow past `MAX_MEM_BODY`; the file
    /// variant reports write failures.
    fn write(&mut self, data: &[u8]) -> Result<(), String> {
        match self {
            BodySink::Mem(bytes) => {
                if bytes.len() + data.len() > MAX_MEM_BODY {
                    return Err("in-memory response exceeds 2 GiB".to_owned());
                }
                bytes.extend_from_slice(data);
                Ok(())
            }
            BodySink::File(file) => file.write_all(data).map_err(io_str),
        }
    }

    /// Consumes the sink and returns the buffered body; a file sink yields an
    /// empty vector because its bytes are already on disk.
    fn into_body(self) -> Vec<u8> {
        if let BodySink::Mem(bytes) = self {
            bytes
        } else {
            Vec::new()
        }
    }
}
/// Decodes a `Transfer-Encoding: chunked` body from `r` into `sink`.
///
/// `leftover` holds body bytes that were already read together with the
/// response headers. Chunk extensions (`;ext`) are ignored. Decoding stops
/// right after the zero-size chunk line; any trailer section and the final
/// CRLF are left unread on `r`.
///
/// Chunk payload is forwarded to `sink` as it arrives and consumed bytes are
/// discarded, so memory use stays bounded by the read buffer no matter how
/// large the body or an individual (server-declared) chunk is.
///
/// # Errors
/// - the stream ends inside a chunk header or chunk body;
/// - a chunk-size line is not valid UTF-8 / hexadecimal, or exceeds 16 KiB;
/// - chunk data is not followed by CRLF;
/// - any other I/O error, or an error from `sink`.
fn read_chunked(
    r: &mut (impl Read + ?Sized),
    leftover: &[u8],
    sink: &mut BodySink,
) -> Result<(), String> {
    /// Upper bound for one chunk-size line, extensions included.
    const MAX_CHUNK_LINE: usize = 16 * 1024;

    /// Reads once from `r` and appends the bytes to `pending`; end of stream
    /// is reported as `eof_msg`.
    fn fill(
        r: &mut (impl Read + ?Sized),
        pending: &mut Vec<u8>,
        scratch: &mut [u8],
        eof_msg: &str,
    ) -> Result<(), String> {
        match r.read(scratch) {
            Ok(0) => Err(eof_msg.to_owned()),
            Ok(n) => {
                pending.extend_from_slice(&scratch[..n]);
                Ok(())
            }
            Err(e) if is_eofish(&e) => Err(eof_msg.to_owned()),
            Err(e) => Err(io_str(e)),
        }
    }

    // `pending` only ever holds bytes that are read but not yet decoded.
    let mut pending = leftover.to_vec();
    let mut scratch = [0u8; 32 * 1024];
    loop {
        // --- chunk-size line ---------------------------------------------
        let line_end = loop {
            if let Some(p) = find_sub(&pending, b"\r\n") {
                break p;
            }
            if pending.len() > MAX_CHUNK_LINE {
                return Err("chunk header exceeds 16 KiB".to_owned());
            }
            fill(&mut *r, &mut pending, &mut scratch, "truncated chunk header")?;
        };
        let size_hex = std::str::from_utf8(&pending[..line_end])
            .map_err(|_| "non-UTF8 chunk header".to_owned())?
            .split(';')
            .next()
            .unwrap_or("")
            .trim();
        let size = usize::from_str_radix(size_hex, 16)
            .map_err(|_| format!("bad chunk size `{size_hex}`"))?;
        pending.drain(..line_end + 2);
        if size == 0 {
            return Ok(());
        }

        // --- chunk payload -----------------------------------------------
        // First whatever is already buffered, then straight from the reader.
        let buffered = size.min(pending.len());
        sink.write(&pending[..buffered])?;
        pending.drain(..buffered);
        let mut remaining = size - buffered;
        while remaining > 0 {
            // `pending` is empty here: everything buffered went to the sink.
            match r.read(&mut scratch) {
                Ok(0) => return Err("truncated chunk body".to_owned()),
                Ok(n) => {
                    let take = n.min(remaining);
                    sink.write(&scratch[..take])?;
                    remaining -= take;
                    // Bytes past the payload belong to what follows the chunk.
                    pending.extend_from_slice(&scratch[take..n]);
                }
                Err(e) if is_eofish(&e) => return Err("truncated chunk body".to_owned()),
                Err(e) => return Err(io_str(e)),
            }
        }

        // --- CRLF that terminates the chunk --------------------------------
        while pending.len() < 2 {
            fill(&mut *r, &mut pending, &mut scratch, "truncated chunk body")?;
        }
        if !pending.starts_with(b"\r\n") {
            return Err("missing CRLF after chunk data".to_owned());
        }
        pending.drain(..2);
    }
}
/// Returns the index of the first occurrence of `needle` in `haystack`.
///
/// An empty `needle`, or one longer than `haystack`, never matches.
fn find_sub(haystack: &[u8], needle: &[u8]) -> Option<usize> {
    let width = needle.len();
    if width == 0 || width > haystack.len() {
        return None;
    }
    let last_start = haystack.len() - width;
    (0..=last_start).find(|&start| haystack[start..].starts_with(needle))
}
/// Resolves a redirect target `loc` (a `Location` value) against `base`.
///
/// Handled forms:
/// - absolute URL (`scheme://…`): returned unchanged;
/// - network-path (`//host/…`): takes the base scheme;
/// - absolute path (`/…`): takes the base scheme, host and port;
/// - query-only (`?…`): replaces the query of the base path;
/// - anything else: treated as relative to the directory of the base path.
///
/// `loc` only counts as absolute when the text before its first `://` is a
/// syntactically valid URI scheme, so a relative reference that merely carries
/// a URL in its query (`/login?next=https://…`) is still resolved against
/// `base`. If `base` itself cannot be parsed, `loc` is returned unchanged.
/// Dot segments (`./`, `../`) are not normalised.
pub(crate) fn resolve_redirect(base: &str, loc: &str) -> String {
    let is_absolute = loc.split_once("://").map_or(false, |(scheme, _)| {
        let mut chars = scheme.chars();
        chars.next().map_or(false, |c| c.is_ascii_alphabetic())
            && chars.all(|c| c.is_ascii_alphanumeric() || matches!(c, '+' | '-' | '.'))
    });
    if is_absolute {
        return loc.to_owned();
    }
    let Ok((https, host, port, path)) = parse_http_url(base) else {
        return loc.to_owned();
    };
    let scheme = if https { "https" } else { "http" };
    let default_port = if https { 443 } else { 80 };
    let portsuf = if port == default_port {
        String::new()
    } else {
        format!(":{port}")
    };
    // Only the path proper takes part in relative resolution; a `/` inside the
    // base query or fragment must not be mistaken for a directory separator.
    let base_path = path
        .split(|c: char| c == '?' || c == '#')
        .next()
        .unwrap_or("");

    if let Some(rest) = loc.strip_prefix("//") {
        format!("{scheme}://{rest}")
    } else if loc.starts_with('/') {
        format!("{scheme}://{host}{portsuf}{loc}")
    } else if loc.starts_with('?') {
        format!("{scheme}://{host}{portsuf}{base_path}{loc}")
    } else {
        let dir = base_path.rsplit_once('/').map(|(d, _)| d).unwrap_or("");
        format!("{scheme}://{host}{portsuf}{dir}/{loc}")
    }
}
/// Looks up the first header called `name` (ASCII case-insensitive) in a raw
/// header block and returns its value with surrounding whitespace removed.
///
/// Lines without a `:` (such as the status line) are skipped. Returns `None`
/// when no line matches.
pub(crate) fn header_value<'a>(headers: &'a str, name: &str) -> Option<&'a str> {
    for line in headers.lines() {
        let Some((key, value)) = line.split_once(':') else {
            continue;
        };
        if key.trim().eq_ignore_ascii_case(name) {
            return Some(value.trim());
        }
    }
    None
}
// Unit tests for the SSRF guard, URL/redirect helpers, response framing, and
// (against loopback listeners) the keep-alive pool.
//
// NOTE(review): the network tests set SUPERMACHINE_BUILD_ALLOW_PRIVATE_FETCH
// for the whole test process and never unset it.
#[cfg(test)]
mod tests {
use super::*;
// The guard must reject loopback/private/link-local/CGNAT/mapped addresses and
// accept ordinary public ones.
#[test]
fn blocks_internal_allows_public() {
for ip in [
"169.254.169.254",
"127.0.0.1",
"10.0.0.5",
"172.16.0.1",
"192.168.1.1",
"100.64.0.1",
"0.0.0.0",
"::1",
"fe80::1",
"fc00::1",
"::ffff:10.0.0.1",
] {
assert!(
is_blocked_ip(&ip.parse().unwrap()),
"{ip} should be blocked"
);
}
for ip in ["8.8.8.8", "1.1.1.1", "140.82.112.3", "2606:4700::1111"] {
assert!(
!is_blocked_ip(&ip.parse().unwrap()),
"{ip} should be allowed"
);
}
}
// Default and explicit ports; unsupported schemes are rejected.
#[test]
fn url_parse() {
assert_eq!(
parse_http_url("https://registry-1.docker.io/v2/x").unwrap(),
(true, "registry-1.docker.io".into(), 443, "/v2/x".into())
);
assert_eq!(
parse_http_url("http://localhost:5000/v2/").unwrap(),
(false, "localhost".into(), 5000, "/v2/".into())
);
assert!(parse_http_url("ftp://x/y").is_err());
}
// Absolute, absolute-path, relative, and network-path redirect targets.
#[test]
fn redirect_resolution_forms() {
let base = "https://h.example/dir/page";
assert_eq!(resolve_redirect(base, "https://other/x"), "https://other/x");
assert_eq!(resolve_redirect(base, "/abs"), "https://h.example/abs");
assert_eq!(
resolve_redirect(base, "rel.bin"),
"https://h.example/dir/rel.bin"
);
assert_eq!(resolve_redirect(base, "//cdn/y"), "https://cdn/y");
}
// Header names match regardless of ASCII case; missing headers give None.
#[test]
fn header_value_case_insensitive() {
let h = "HTTP/1.1 302 Found\r\nLocation: https://elsewhere/x\r\nContent-Length: 0";
assert_eq!(header_value(h, "location"), Some("https://elsewhere/x"));
assert_eq!(header_value(h, "CONTENT-LENGTH"), Some("0"));
assert_eq!(header_value(h, "x-missing"), None);
}
// Reader that hands out one byte per `read` call, to exercise every
// "need more data" path in the parsers.
struct OneByteReader<'a> {
data: &'a [u8],
pos: usize,
}
impl Read for OneByteReader<'_> {
fn read(&mut self, out: &mut [u8]) -> std::io::Result<usize> {
if self.pos >= self.data.len() || out.is_empty() {
return Ok(0);
}
out[0] = self.data[self.pos];
self.pos += 1;
Ok(1)
}
}
// Bytes after the declared Content-Length must not end up in the body.
#[test]
fn reads_exactly_content_length_not_trailing() {
let msg = b"HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\nhelloTRAILING_GARBAGE";
let mut r = OneByteReader { data: msg, pos: 0 };
let (resp, reusable) = read_response(&mut r, None).unwrap();
assert_eq!(resp.status, 200);
assert_eq!(resp.body, b"hello");
assert!(
reusable,
"content-length response should be keep-alive reusable"
);
}
// Chunks are concatenated; chunked connections are never pooled.
#[test]
fn reads_chunked_body() {
let msg =
b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n4\r\nWiki\r\n5\r\npedia\r\n0\r\n\r\n";
let mut r = OneByteReader { data: msg, pos: 0 };
let (resp, reusable) = read_response(&mut r, None).unwrap();
assert_eq!(resp.body, b"Wikipedia");
assert!(!reusable, "chunked response must not be pooled");
}
// Without any framing header the body runs to EOF and the connection is spent.
#[test]
fn reads_to_eof_without_length() {
let msg = b"HTTP/1.1 200 OK\r\nServer: x\r\n\r\nbody-bytes";
let mut r = OneByteReader { data: msg, pos: 0 };
let (resp, reusable) = read_response(&mut r, None).unwrap();
assert_eq!(resp.body, b"body-bytes");
assert!(!reusable, "read-to-EOF response must not be pooled");
}
// 204/304 carry no body even if Content-Length is present; `Connection: close`
// still prevents reuse.
#[test]
fn no_body_status_has_empty_body_and_is_reusable() {
for status in ["204 No Content", "304 Not Modified"] {
let msg = format!("HTTP/1.1 {status}\r\nContent-Length: 7\r\n\r\n");
let mut r = OneByteReader {
data: msg.as_bytes(),
pos: 0,
};
let (resp, reusable) = read_response(&mut r, None).unwrap();
assert!(resp.body.is_empty(), "{status}: body must be empty");
assert!(reusable, "{status}: bodyless response should be reusable");
}
let mut r = OneByteReader {
data: b"HTTP/1.1 204 No Content\r\nConnection: close\r\n\r\n",
pos: 0,
};
let (_resp, reusable) = read_response(&mut r, None).unwrap();
assert!(
!reusable,
"Connection: close must bar reuse even with no body"
);
}
// A header block that never terminates within MAX_HEADERS is an error.
#[test]
fn header_overflow_rejected() {
let big = vec![b'x'; MAX_HEADERS + 1];
let mut r = std::io::Cursor::new(big);
assert!(read_response(&mut r, None).is_err());
}
// A first line without a numeric status code is an error mentioning "status".
#[test]
fn malformed_status_line_rejected() {
let msg = b"GARBAGE NOT A STATUS\r\nContent-Length: 0\r\n\r\n";
let mut r = std::io::Cursor::new(msg.to_vec());
let err = read_response(&mut r, None).unwrap_err();
assert!(err.contains("status"), "got: {err}");
}
// EOF before the blank line that ends the headers is an error.
#[test]
fn connection_closed_before_headers_rejected() {
let mut r = std::io::Cursor::new(b"HTTP/1.1 200 OK\r\n".to_vec());
assert!(read_response(&mut r, None).is_err());
}
// A chunk shorter than its declared size is an error.
#[test]
fn truncated_chunk_rejected() {
let msg = b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n4\r\nhi";
let mut r = std::io::Cursor::new(msg.to_vec());
assert!(read_response(&mut r, None).is_err());
}
// With an output path the body lands on disk and `body` stays empty.
#[test]
fn streams_body_to_file() {
let msg = b"HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\nhelloIGNORED";
let mut r = OneByteReader { data: msg, pos: 0 };
let tmp = std::env::temp_dir().join(format!("sm-net-test-{}.bin", std::process::id()));
let (resp, _reusable) = read_response(&mut r, Some(&tmp)).unwrap();
assert_eq!(resp.status, 200);
assert!(
resp.body.is_empty(),
"body should be on disk, not in memory"
);
assert_eq!(std::fs::read(&tmp).unwrap(), b"hello");
let _ = std::fs::remove_file(&tmp);
}
// An https URL pointed at a plaintext server must fail, not return data.
#[test]
fn tls_handshake_failure_errors() {
std::env::set_var("SUPERMACHINE_BUILD_ALLOW_PRIVATE_FETCH", "1");
let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
let port = listener.local_addr().unwrap().port();
let h = std::thread::spawn(move || {
if let Ok((mut s, _)) = listener.accept() {
let _ = s.write_all(b"HTTP/1.1 200 OK\r\n\r\nplaintext");
}
});
let url = format!("https://127.0.0.1:{port}/");
let res = http_get_once(&url, &[], None);
assert!(res.is_err(), "expected TLS handshake failure, got {res:?}");
let _ = h.join();
}
// Two sequential requests to a keep-alive server must share one TCP
// connection (the server accepts exactly once).
#[test]
fn keep_alive_reuses_one_connection() {
use std::sync::atomic::{AtomicUsize, Ordering};
std::env::set_var("SUPERMACHINE_BUILD_ALLOW_PRIVATE_FETCH", "1");
let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
let port = listener.local_addr().unwrap().port();
let accepts = Arc::new(AtomicUsize::new(0));
let accepts_srv = accepts.clone();
let h = std::thread::spawn(move || {
if let Ok((mut s, _)) = listener.accept() {
accepts_srv.fetch_add(1, Ordering::SeqCst);
let mut buf = [0u8; 1024];
for body in ["first", "second"] {
let mut got = Vec::new();
loop {
match s.read(&mut buf) {
Ok(0) => return,
Ok(n) => got.extend_from_slice(&buf[..n]),
Err(_) => return,
}
if find_sub(&got, b"\r\n\r\n").is_some() {
break;
}
}
let resp = format!(
"HTTP/1.1 200 OK\r\nContent-Length: {}\r\nConnection: keep-alive\r\n\r\n{}",
body.len(),
body
);
if s.write_all(resp.as_bytes()).is_err() {
return;
}
}
}
});
let url = format!("http://127.0.0.1:{port}/");
let r1 = http_get_once(&url, &[], None).unwrap();
assert_eq!(r1.body, b"first");
let r2 = http_get_once(&url, &[], None).unwrap();
assert_eq!(r2.body, b"second");
assert_eq!(
accepts.load(Ordering::SeqCst),
1,
"second request should reuse the pooled connection, not dial again"
);
let _ = h.join();
}
// Server-side helper: reads until the end of one request's header block.
// Returns false if the peer closes or errors first.
fn drain_one_request(s: &mut std::net::TcpStream) -> bool {
let mut buf = [0u8; 1024];
let mut got = Vec::new();
loop {
match s.read(&mut buf) {
Ok(0) => return false,
Ok(n) => got.extend_from_slice(&buf[..n]),
Err(_) => return false,
}
if find_sub(&got, b"\r\n\r\n").is_some() {
return true;
}
}
}
// A `Connection: close` response must force a new dial for the next request
// (the server accepts twice).
#[test]
fn connection_close_response_is_not_pooled() {
use std::sync::atomic::{AtomicUsize, Ordering};
std::env::set_var("SUPERMACHINE_BUILD_ALLOW_PRIVATE_FETCH", "1");
let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
let port = listener.local_addr().unwrap().port();
let accepts = Arc::new(AtomicUsize::new(0));
let accepts_srv = accepts.clone();
let h = std::thread::spawn(move || {
for body in ["first", "second"] {
let Ok((mut s, _)) = listener.accept() else {
return;
};
accepts_srv.fetch_add(1, Ordering::SeqCst);
if !drain_one_request(&mut s) {
return;
}
let resp = format!(
"HTTP/1.1 200 OK\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
body.len(),
body
);
let _ = s.write_all(resp.as_bytes());
}
});
let url = format!("http://127.0.0.1:{port}/");
assert_eq!(http_get_once(&url, &[], None).unwrap().body, b"first");
assert_eq!(http_get_once(&url, &[], None).unwrap().body, b"second");
assert_eq!(
accepts.load(Ordering::SeqCst),
2,
"Connection: close must not be pooled — second request should dial fresh"
);
let _ = h.join();
}
// The server promises keep-alive but shuts the socket down; the client must
// notice the dead pooled connection and transparently dial again.
#[test]
fn dead_pooled_connection_redials() {
use std::sync::atomic::{AtomicUsize, Ordering};
std::env::set_var("SUPERMACHINE_BUILD_ALLOW_PRIVATE_FETCH", "1");
let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
let port = listener.local_addr().unwrap().port();
let accepts = Arc::new(AtomicUsize::new(0));
let accepts_srv = accepts.clone();
let h = std::thread::spawn(move || {
if let Ok((mut s, _)) = listener.accept() {
accepts_srv.fetch_add(1, Ordering::SeqCst);
if drain_one_request(&mut s) {
let _ = s.write_all(
b"HTTP/1.1 200 OK\r\nContent-Length: 5\r\nConnection: keep-alive\r\n\r\nfirst",
);
}
let _ = s.shutdown(std::net::Shutdown::Both);
}
if let Ok((mut s, _)) = listener.accept() {
accepts_srv.fetch_add(1, Ordering::SeqCst);
if drain_one_request(&mut s) {
let _ = s.write_all(
b"HTTP/1.1 200 OK\r\nContent-Length: 6\r\nConnection: close\r\n\r\nsecond",
);
}
}
});
let url = format!("http://127.0.0.1:{port}/");
assert_eq!(http_get_once(&url, &[], None).unwrap().body, b"first");
// Give the server's shutdown time to reach the client socket.
std::thread::sleep(std::time::Duration::from_millis(50));
assert_eq!(http_get_once(&url, &[], None).unwrap().body, b"second");
assert_eq!(
accepts.load(Ordering::SeqCst),
2,
"a dead pooled connection must trigger a fresh dial, not a hard error"
);
let _ = h.join();
}
}