use std::io::Read;
use std::time::Duration;
use host_identity::transport::HttpTransport;
/// Largest response body, in bytes, that the transport will buffer (1 MiB); anything longer
/// is rejected with an error.
const MAX_RESPONSE_BYTES: u64 = 1024 * 1024;
/// Default overall request timeout of 750 ms.
///
/// `pub(crate)`, presumably so other modules can hand it to `UreqTransport::with_timeout`
/// — confirm against callers.
pub(crate) const DEFAULT_NETWORK_TIMEOUT: Duration = Duration::from_millis(750);
/// [`HttpTransport`] implementation that sends requests through a [`ureq::Agent`].
#[derive(Clone)]
pub(crate) struct UreqTransport {
    /// Agent carrying the connect and overall timeouts chosen in `with_timeout`.
    agent: ureq::Agent,
}

impl UreqTransport {
    /// Creates a transport whose requests are bounded by `timeout`.
    ///
    /// Half of `timeout` is passed to ureq as the connect timeout. The full `timeout` is
    /// passed as ureq's overall request timeout.
    pub(crate) fn with_timeout(timeout: Duration) -> Self {
        let connect_budget = timeout / 2;
        Self {
            agent: ureq::AgentBuilder::new()
                .timeout_connect(connect_budget)
                .timeout(timeout)
                .build(),
        }
    }
}
/// Error returned by the ureq-backed transport.
///
/// Wraps a human-readable description of one of three failures:
/// - a transport failure;
/// - a failure while reading the body (including exceeding the size cap);
/// - a failure while building the `http::Response`.
#[derive(Debug)]
pub(crate) struct UreqError(String);

impl std::fmt::Display for UreqError {
    /// Renders as `ureq transport error: <detail>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(detail) = self;
        write!(f, "ureq transport error: {detail}")
    }
}

/// Marker impl; `UreqError` has no underlying `source`.
impl std::error::Error for UreqError {}
impl HttpTransport for UreqTransport {
type Error = UreqError;
fn send(
&self,
request: http::Request<Vec<u8>>,
) -> Result<http::Response<Vec<u8>>, Self::Error> {
let (parts, body) = request.into_parts();
let uri = parts.uri.to_string();
let mut req = self.agent.request(parts.method.as_str(), &uri);
for (name, value) in &parts.headers {
if let Ok(v) = value.to_str() {
req = req.set(name.as_str(), v);
}
}
let response = if body.is_empty() {
req.call()
} else {
req.send_bytes(&body)
};
let ureq_resp = match response {
Ok(r) | Err(ureq::Error::Status(_, r)) => r,
Err(ureq::Error::Transport(t)) => return Err(UreqError(t.to_string())),
};
into_http_response(ureq_resp)
}
}
/// Converts a ureq response into an [`http::Response`] with a fully buffered body.
///
/// Status and headers are copied across. Every value of a repeated header (for example
/// several `Set-Cookie` lines) is preserved. Header values that ureq cannot expose as
/// `&str` are skipped.
///
/// # Errors
///
/// Returns [`UreqError`] in two cases:
/// - the body cannot be read, or exceeds `MAX_RESPONSE_BYTES` (see `read_capped_body`);
/// - the `http` builder rejects the status or a header name/value. The builder defers
///   these errors until `body()` is called.
fn into_http_response(resp: ureq::Response) -> Result<http::Response<Vec<u8>>, UreqError> {
    let mut builder = http::Response::builder().status(resp.status());
    // `headers_names` yields one (lowercased) entry per header *line*, so a repeated header
    // shows up multiple times. `header()` only ever returns the first value. Dedupe the names
    // and copy every value via `all()` so nothing is duplicated or dropped.
    let mut names = resp.headers_names();
    names.sort_unstable();
    names.dedup();
    for name in &names {
        for value in resp.all(name) {
            builder = builder.header(name.as_str(), value);
        }
    }
    let body = read_capped_body(resp)?;
    builder.body(body).map_err(|e| UreqError(e.to_string()))
}
/// Reads the entire body of `resp` into memory, enforcing the `MAX_RESPONSE_BYTES` cap.
///
/// # Errors
///
/// Returns [`UreqError`] in two cases:
/// - reading from the response stream fails;
/// - the body is longer than `MAX_RESPONSE_BYTES`.
fn read_capped_body(resp: ureq::Response) -> Result<Vec<u8>, UreqError> {
    let mut collected = Vec::new();
    // Permit one byte past the cap: receiving it proves the body is oversize, without
    // buffering an unbounded amount. The reader is a temporary dropped at the end of this
    // statement.
    let outcome = resp
        .into_reader()
        .take(MAX_RESPONSE_BYTES + 1)
        .read_to_end(&mut collected);
    if let Err(io_err) = outcome {
        return Err(UreqError(io_err.to_string()));
    }
    match collected.len() as u64 {
        len if len > MAX_RESPONSE_BYTES => Err(UreqError(format!(
            "response body exceeded {MAX_RESPONSE_BYTES} bytes"
        ))),
        _ => Ok(collected),
    }
}
#[cfg(test)]
mod tests {
    //! Tests that drive `UreqTransport` against a one-shot TCP server bound to
    //! `127.0.0.1`, so no external network access is needed.
    use super::*;
    use std::io::Write;
    use std::net::{TcpListener, TcpStream};
    use std::thread::{self, JoinHandle};
    use std::time::Instant;

    /// Bounds both the mock server's wait for a connection and its socket read/write
    /// timeouts, so a misbehaving test cannot hang the suite indefinitely.
    const MOCK_SERVER_DEADLINE: Duration = Duration::from_secs(10);

    /// A background thread that accepts one connection, reads part of the request, writes a
    /// canned raw HTTP response, and exits.
    struct MockServer {
        // Base URL (`http://127.0.0.1:<port>/`) of the listening socket.
        url: String,
        // Server thread; taken and joined in `Drop`.
        handle: Option<JoinHandle<()>>,
    }

    impl MockServer {
        /// Binds an ephemeral loopback port and spawns a thread that writes `response`
        /// verbatim to the first client that connects.
        fn serve_once(response: Vec<u8>) -> Self {
            let listener = TcpListener::bind("127.0.0.1:0").expect("bind");
            let addr = listener.local_addr().expect("local_addr");
            // Non-blocking accept lets the thread poll against a deadline instead of
            // blocking forever if no client ever connects.
            listener
                .set_nonblocking(true)
                .expect("set_nonblocking on listener");
            let handle = thread::spawn(move || {
                let deadline = Instant::now() + MOCK_SERVER_DEADLINE;
                let Some(mut stream) = accept_with_deadline(&listener, deadline) else {
                    return;
                };
                configure_server_stream(&mut stream);
                drain_request(&mut stream);
                // Write errors are ignored; presumably the client may hang up early
                // (e.g. once it has read past the body-size cap).
                let _ = stream.write_all(&response);
            });
            Self {
                url: format!("http://{addr}/"),
                handle: Some(handle),
            }
        }
    }

    /// Polls a non-blocking `listener` every 10 ms until a connection arrives.
    ///
    /// Returns `None` in two cases:
    /// - `deadline` passes;
    /// - `accept` fails with anything other than `WouldBlock`.
    fn accept_with_deadline(listener: &TcpListener, deadline: Instant) -> Option<TcpStream> {
        loop {
            match listener.accept() {
                Ok((stream, _)) => return Some(stream),
                Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
                    if Instant::now() >= deadline {
                        return None;
                    }
                    thread::sleep(Duration::from_millis(10));
                }
                Err(_) => return None,
            }
        }
    }

    /// Puts an accepted stream into blocking mode and bounds reads and writes by
    /// `MOCK_SERVER_DEADLINE`.
    ///
    /// Blocking mode is forced in case the stream inherited the listener's non-blocking
    /// flag. All calls are best effort and their errors are ignored.
    fn configure_server_stream(stream: &mut TcpStream) {
        let _ = stream.set_nonblocking(false);
        let _ = stream.set_read_timeout(Some(MOCK_SERVER_DEADLINE));
        let _ = stream.set_write_timeout(Some(MOCK_SERVER_DEADLINE));
    }

    impl Drop for MockServer {
        /// Joins the server thread; a panic inside it is ignored.
        fn drop(&mut self) {
            if let Some(handle) = self.handle.take() {
                let _ = handle.join();
            }
        }
    }

    /// Performs a single `read` of at most 1024 bytes of the incoming request and discards
    /// the data.
    ///
    /// A request larger than the buffer is not fully consumed, and read errors are ignored.
    fn drain_request(stream: &mut TcpStream) {
        use std::io::Read;
        let mut buf = [0u8; 1024];
        let _ = stream.read(&mut buf);
    }

    /// Sends a body-less GET to `url` through a transport with a 5 s timeout.
    fn get(url: &str) -> Result<http::Response<Vec<u8>>, UreqError> {
        let transport = UreqTransport::with_timeout(Duration::from_secs(5));
        let request = http::Request::builder()
            .method("GET")
            .uri(url)
            .body(Vec::new())
            .expect("request builder");
        transport.send(request)
    }

    /// A 4xx status must come back as a normal response, with its body intact.
    #[test]
    fn returns_non_2xx_as_response_not_error() {
        let server = MockServer::serve_once(
            b"HTTP/1.1 404 Not Found\r\nContent-Length: 9\r\n\r\nnot found".to_vec(),
        );
        let resp = get(&server.url).expect("404 must not be a transport error");
        assert_eq!(resp.status(), 404);
        assert_eq!(resp.body(), b"not found");
    }

    /// A 200 response's status and body are passed through unchanged.
    #[test]
    fn propagates_2xx_body_and_status() {
        let server = MockServer::serve_once(
            b"HTTP/1.1 200 OK\r\nContent-Length: 13\r\n\r\nhello-hostid!".to_vec(),
        );
        let resp = get(&server.url).expect("200 must succeed");
        assert_eq!(resp.status(), 200);
        assert_eq!(resp.body(), b"hello-hostid!");
    }

    /// A body 16 bytes over `MAX_RESPONSE_BYTES` must produce an error rather than a
    /// response.
    #[test]
    fn rejects_oversize_response_body() {
        let cap = usize::try_from(MAX_RESPONSE_BYTES).expect("cap fits in usize on test targets");
        let big = vec![b'x'; cap + 16];
        let mut payload =
            format!("HTTP/1.1 200 OK\r\nContent-Length: {}\r\n\r\n", big.len()).into_bytes();
        payload.extend_from_slice(&big);
        let server = MockServer::serve_once(payload);
        let err = get(&server.url).expect_err("oversize body must error");
        let msg = err.to_string();
        // Either the size-cap message or a premature-close message is accepted.
        assert!(
            msg.contains("exceeded") || msg.contains("closed before"),
            "oversize body should surface as cap or premature-close error: {err}",
        );
    }

    /// Connecting to a port with no listener must yield a transport error.
    #[test]
    fn transport_error_when_connection_refused() {
        // Bind then drop a listener to obtain a port that is very likely closed.
        let listener = TcpListener::bind("127.0.0.1:0").expect("bind");
        let addr = listener.local_addr().expect("local_addr");
        drop(listener);
        let err =
            get(&format!("http://{addr}/")).expect_err("closed port must return a transport error");
        let msg = err.to_string().to_ascii_lowercase();
        assert!(
            msg.contains("refused") || msg.contains("connect"),
            "expected connection-refused-style transport error, got: {err}"
        );
    }
}