use std::io::Read as _;
use std::time::Duration;
use uni_plugin::{FnError, HttpEgress, HttpResponse};
/// Error code: the blocking `reqwest` client could not be constructed.
const ERR_CLIENT_BUILD: u32 = 0x0B00;
/// Error code: sending the request, or reading the response body, failed.
const ERR_TRANSPORT: u32 = 0x0B01;
/// Error code: the worker thread that runs the request failed (e.g. it panicked).
const ERR_WORKER_PANIC: u32 = 0x0B02;
/// HTTP egress backed by `reqwest`'s blocking client.
///
/// The type is a stateless unit struct: constructing or cloning it is free,
/// and all per-request configuration is passed to each call.
#[derive(Debug, Default, Clone)]
pub struct BlockingHttpEgress;

impl BlockingHttpEgress {
    /// Creates a new egress handle.
    ///
    /// Equivalent to `BlockingHttpEgress::default()`, but also usable in
    /// `const` contexts.
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}
impl HttpEgress for BlockingHttpEgress {
    /// Issues a blocking request without a body to `url`.
    ///
    /// The work is delegated to `run_on_dedicated_thread`, which executes it
    /// on a separate worker thread and waits for the result.
    fn get(
        &self,
        url: &str,
        timeout: Duration,
        max_bytes: usize,
        traceparent: Option<&str>,
    ) -> Result<HttpResponse, FnError> {
        // No payload selects the GET path in the request helper.
        let no_payload: Option<&[u8]> = None;
        run_on_dedicated_thread(url, no_payload, timeout, max_bytes, traceparent)
    }

    /// Issues a blocking request to `url` carrying `body` as its payload.
    ///
    /// The work is delegated to `run_on_dedicated_thread`, which executes it
    /// on a separate worker thread and waits for the result.
    fn post(
        &self,
        url: &str,
        body: &[u8],
        timeout: Duration,
        max_bytes: usize,
        traceparent: Option<&str>,
    ) -> Result<HttpResponse, FnError> {
        // A present payload selects the POST path in the request helper.
        let payload: Option<&[u8]> = Some(body);
        run_on_dedicated_thread(url, payload, timeout, max_bytes, traceparent)
    }
}
/// Runs `do_request` on a freshly spawned scoped worker thread and waits for it.
///
/// `std::thread::scope` lets the worker borrow `url`, `body` and `traceparent`
/// directly instead of copying them into owned values. Presumably the
/// dedicated thread exists to keep `reqwest`'s blocking client off the
/// caller's thread (it must not block inside an async runtime).
/// NOTE(review): confirm against the callers' threading model.
///
/// # Errors
///
/// Returns whatever `do_request` returns. It also returns an error with code
/// `ERR_WORKER_PANIC` in two cases:
/// * the OS refuses to create the worker thread;
/// * the worker panics. The panic message is appended when it is a string.
fn run_on_dedicated_thread(
    url: &str,
    body: Option<&[u8]>,
    timeout: Duration,
    max_bytes: usize,
    traceparent: Option<&str>,
) -> Result<HttpResponse, FnError> {
    std::thread::scope(|scope| -> Result<HttpResponse, FnError> {
        // `Builder::spawn_scoped` reports thread-creation failure as an
        // `io::Error`; `Scope::spawn` would panic instead.
        let handle = std::thread::Builder::new()
            .name("http-egress".to_owned())
            .spawn_scoped(scope, || {
                do_request(url, body, timeout, max_bytes, traceparent)
            })
            .map_err(|e| {
                FnError::new(
                    ERR_WORKER_PANIC,
                    format!("failed to spawn http request worker thread: {e}"),
                )
            })?;
        // Joining explicitly consumes the panic, so `scope` will not re-panic.
        handle.join().unwrap_or_else(|payload| {
            // Panic payloads are `&'static str` for literal messages and
            // `String` for formatted ones; anything else carries no text.
            let detail = payload
                .downcast_ref::<&str>()
                .copied()
                .or_else(|| payload.downcast_ref::<String>().map(String::as_str));
            let message = match detail {
                Some(text) => format!("http request worker thread panicked: {text}"),
                None => "http request worker thread panicked".to_owned(),
            };
            Err(FnError::new(ERR_WORKER_PANIC, message))
        })
    })
}
/// Performs one HTTP request with a freshly built blocking `reqwest` client.
///
/// Sends a `POST` carrying `body` when `body` is `Some`, otherwise a `GET`.
/// Redirects are not followed: a 3xx response is returned as-is. When
/// `traceparent` is given, it is forwarded as the `traceparent` header.
/// At most `max_bytes` bytes of the response body are read; anything beyond
/// that is left unread and dropped with the response.
///
/// # Errors
///
/// * `ERR_CLIENT_BUILD` if the client cannot be built.
/// * `ERR_TRANSPORT` if `reqwest` fails to build or send the request. This
///   covers failures it reports for the URL, connection or configured
///   `timeout`. It also covers failures while reading the body.
fn do_request(
    url: &str,
    body: Option<&[u8]>,
    timeout: Duration,
    max_bytes: usize,
    traceparent: Option<&str>,
) -> Result<HttpResponse, FnError> {
    let client = reqwest::blocking::Client::builder()
        .timeout(timeout)
        // Never chase redirects: a `Location` header could steer the request
        // to a host the caller never asked to reach.
        .redirect(reqwest::redirect::Policy::none())
        .build()
        .map_err(|e| FnError::new(ERR_CLIENT_BUILD, format!("http client build: {e}")))?;
    let mut request = match body {
        Some(b) => client.post(url).body(b.to_vec()),
        None => client.get(url),
    };
    if let Some(tp) = traceparent {
        request = request.header("traceparent", tp);
    }
    let response = request
        .send()
        .map_err(|e| FnError::new(ERR_TRANSPORT, format!("http send `{url}`: {e}")))?;
    let status = response.status().as_u16();
    // `Take` makes `read_to_end` stop once `max_bytes` bytes have arrived, so
    // the buffer never exceeds the cap and no extra bytes are read only to be
    // discarded. `usize` is at most 64 bits on every Rust target, so the
    // widening cast is lossless.
    let limit = max_bytes as u64;
    let mut buf = Vec::new();
    response
        .take(limit)
        .read_to_end(&mut buf)
        .map_err(|e| FnError::new(ERR_TRANSPORT, format!("http body `{url}`: {e}")))?;
    Ok(HttpResponse { status, body: buf })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Binds an ephemeral localhost port and answers the first accepted
    /// connection with `response`, written verbatim after a single read of the
    /// request. Returns the bound address and the server thread's handle.
    fn serve_once(
        response: &'static str,
    ) -> (std::net::SocketAddr, std::thread::JoinHandle<()>) {
        use std::io::Write as _;
        let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();
        let server = std::thread::spawn(move || {
            if let Ok((mut stream, _)) = listener.accept() {
                let mut buf = [0u8; 1024];
                let _ = std::io::Read::read(&mut stream, &mut buf);
                let _ = stream.write_all(response.as_bytes());
            }
        });
        (addr, server)
    }

    #[test]
    fn build_and_default() {
        let _ = BlockingHttpEgress::new();
        let _ = BlockingHttpEgress;
    }

    /// A 302 pointing at a link-local metadata address must be returned to the
    /// caller, not followed.
    #[test]
    fn redirect_is_not_followed() {
        let (addr, server) = serve_once(
            "HTTP/1.1 302 Found\r\n\
             Location: http://169.254.169.254/latest/meta-data\r\n\
             Content-Length: 0\r\n\r\n",
        );
        let egress = BlockingHttpEgress::new();
        let url = format!("http://{addr}/");
        let resp = egress
            .get(&url, Duration::from_millis(500), 1024, None)
            .expect("redirect response should be returned, not followed");
        assert_eq!(resp.status, 302, "redirect must be surfaced, not chased");
        let _ = server.join();
    }

    /// Bodies longer than `max_bytes` are cut to exactly `max_bytes` bytes.
    #[test]
    fn body_is_capped_at_max_bytes() {
        let (addr, server) = serve_once(
            "HTTP/1.1 200 OK\r\n\
             Content-Length: 10\r\n\r\n\
             0123456789",
        );
        let egress = BlockingHttpEgress::new();
        let url = format!("http://{addr}/");
        let resp = egress
            .get(&url, Duration::from_millis(500), 4, None)
            .expect("capped response should be returned");
        assert_eq!(resp.status, 200);
        assert_eq!(resp.body, b"0123");
        let _ = server.join();
    }

    /// Connecting to a port with no listener (port 1 is assumed unused) must
    /// yield `ERR_TRANSPORT`, not a panic.
    #[test]
    fn dead_port_is_transport_error_not_panic() {
        let egress = BlockingHttpEgress::new();
        let err = egress
            .get(
                "http://127.0.0.1:1/",
                Duration::from_millis(200),
                1024,
                None,
            )
            .expect_err("connection to a dead port must fail");
        assert_eq!(err.code, ERR_TRANSPORT);
    }

    /// A URL that cannot be parsed must also yield `ERR_TRANSPORT`, not a panic.
    #[test]
    fn invalid_url_is_transport_error_not_panic() {
        let egress = BlockingHttpEgress::new();
        let err = egress
            .get("not a url", Duration::from_millis(200), 1024, None)
            .expect_err("a malformed URL must fail");
        assert_eq!(err.code, ERR_TRANSPORT);
    }
}