use std::{
io,
io::Read,
net::{IpAddr, Ipv4Addr, SocketAddr, TcpListener},
time::Duration,
};
use ureq::Agent;
pub use super::ServerClient;
use super::error::*;
/// Global timeout handed to the `ureq` agent (180 seconds); also the duration reported in
/// `ClientError::Timeout` when a request times out.
const TIMEOUT: Duration = Duration::from_secs(180);
/// Host used when the caller of `HttpClient::new` passes `None`.
const HOST: &str = "127.0.0.1";
/// Client that talks to a server over plain HTTP through a `ureq` [`Agent`].
#[derive(Debug)]
pub struct HttpClient {
/// Agent used to issue every request; built with `TIMEOUT` as its global timeout.
agent: Agent,
/// Identifier returned by `pid_id()`. `new` derives it from the lower-cased,
/// filename-sanitized string `{executable_name}_http_{host}_{port}`.
pid_id: String,
/// URL prefix to which request paths are appended when sending.
base_url: String,
/// Host this client targets.
pub host: String,
/// TCP port this client targets.
pub port: u16,
}
impl HttpClient {
pub fn new(executable_name: &str, host: Option<&str>, port: Option<u16>) -> Result<Self> {
let host = host.unwrap_or_else(|| HOST).to_string();
let port = if let Some(port) = port {
port
} else {
let listener: TcpListener =
TcpListener::bind(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0)).map_err(
|e| ClientError::Setup {
reason: format!("failed to obtain an ephemeral port: {e}"),
},
)?;
let port = listener
.local_addr()
.map_err(|e| ClientError::Setup {
reason: format!("could not read local address: {e}"),
})?
.port();
drop(listener); port
};
let agent = Agent::new_with_config(
Agent::config_builder()
.timeout_global(Some(TIMEOUT)) .build(),
);
let pid_id = sanitize_filename::sanitize(
format!("{executable_name}_http_{host}_{port}").to_ascii_lowercase(),
);
if pid_id.len() > 240 {
return Err(ClientError::Setup {
reason: format!("pid_id \"{pid_id}\" exceeds 240 characters"),
});
}
let client = Self {
base_url: format!("http://{host}:{port}"),
host,
agent,
port,
pid_id,
};
crate::trace!("Client created: {client}");
Ok(client)
}
fn send(&self, verb: &'static str, path: &str, body: Option<&[u8]>) -> Result<Vec<u8>> {
debug_assert!(path.starts_with('/'));
let url = format!("{}{}", self.base_url, path);
let response = match (verb, body) {
("GET", _) => self.agent.get(&url).call(),
("POST", Some(b)) if !b.is_empty() => self
.agent
.post(&url)
.content_type("application/json")
.send(b),
("POST", _) => self
.agent
.post(&url)
.content_type("application/json")
.send_empty(),
_ => unreachable!("unsupported verb"),
};
match response {
Ok(resp) if (200..300).contains(&resp.status().as_u16()) => {
let mut body = Vec::new();
resp.into_body().into_reader().read_to_end(&mut body)?;
Ok(body)
}
Ok(resp) => Err(ClientError::Remote {
code: resp.status().as_u16(),
message: resp
.status()
.canonical_reason()
.unwrap_or("unknown error")
.to_string(),
}),
Err(ureq::Error::StatusCode(code)) => Err(ClientError::Remote {
code,
message: format!("HTTP {code}"),
}),
Err(ureq::Error::Timeout(_)) => Err(ClientError::Timeout(TIMEOUT)),
Err(ureq::Error::Io(e)) => Err(ClientError::Io(e)),
Err(ureq::Error::Protocol(p)) => Err(ClientError::Io(io::Error::new(
io::ErrorKind::InvalidData,
format!("protocol error: {p}"),
))),
Err(ureq::Error::BadUri(u)) => Err(ClientError::Setup {
reason: format!("bad URI: {u}"),
}),
Err(other) => Err(ClientError::Io(io::Error::new(
io::ErrorKind::Other,
format!("ureq error: {other}"),
))),
}
}
/// Builds a placeholder client for unit tests.
///
/// No port is probed: the client points at port 0 on `HOST` and carries a fixed `pid_id`.
#[cfg(test)]
pub fn dummy() -> Self {
    let agent = Agent::new_with_config(
        Agent::config_builder()
            .timeout_global(Some(TIMEOUT))
            .build(),
    );
    HttpClient {
        // Carry the `http://` scheme exactly like `new` does, so URLs built from
        // `base_url` have the same shape for both constructors.
        base_url: format!("http://{HOST}:0"),
        agent,
        port: 0,
        pid_id: "dummy_http_client.pid".to_string(),
        host: HOST.to_string(),
    }
}
}
/// Renders the client as `HttpClient(<base_url>)`, with the URL in its debug (quoted) form.
impl std::fmt::Display for HttpClient {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { base_url, .. } = self;
        f.write_fmt(format_args!("HttpClient({:#?})", base_url))
    }
}
impl ServerClient for HttpClient {
    /// Sends a GET request for `path` and returns the raw response body.
    fn get_raw(&self, path: &str) -> Result<Vec<u8>> {
        let no_body: Option<&[u8]> = None;
        self.send("GET", path, no_body)
    }

    /// Sends `body` in a POST request to `path` and returns the raw response body.
    fn post_raw(&self, path: &str, body: &[u8]) -> Result<Vec<u8>> {
        let payload = Some(body);
        self.send("POST", path, payload)
    }

    /// Does nothing and always reports success.
    fn stop(&self) -> Result<()> {
        Ok(())
    }

    /// Returns an owned copy of the target host.
    fn host(&self) -> String {
        self.host.clone()
    }

    /// Returns an owned copy of the client's identifier.
    fn pid_id(&self) -> String {
        self.pid_id.clone()
    }
}
#[cfg(test)]
mod tests {
use serde_json::json;
use serial_test::serial;
use super::*;
use crate::server::ipc::ServerClientExt;
/// Starts a background thread serving exactly one canned HTTP reply on `127.0.0.1:port`.
///
/// The listener is bound before this function returns, so a client may connect right away.
/// The thread accepts a single connection, consumes the request, writes `response`
/// verbatim, optionally sleeps for `keep_open_ms` milliseconds, and then shuts down the
/// write half of the connection.
///
/// # Panics
///
/// Panics if the port cannot be bound. Write failures panic inside the spawned thread.
fn spawn_http_server(
    port: u16,
    response: &'static [u8],
    keep_open_ms: u64,
) -> std::thread::JoinHandle<()> {
    use std::{
        io::{Read, Write},
        net::{Shutdown, TcpListener, TcpStream},
        time::Duration,
    };

    /// Best-effort read of one whole request: the head up to the blank line plus as many
    /// body bytes as its `Content-Length` header announces (0 when absent).
    ///
    /// A single `read` is not enough: the request may arrive in several segments, and
    /// closing a socket that still holds unread data can reset the connection before the
    /// client has read the reply.
    fn drain_request(stream: &mut TcpStream) {
        const HEAD_END: &[u8] = b"\r\n\r\n";
        // Never block forever on a client that sends less than it announced.
        let _ = stream.set_read_timeout(Some(Duration::from_secs(2)));
        let mut request = Vec::new();
        let mut chunk = [0u8; 512];
        loop {
            match stream.read(&mut chunk) {
                Ok(0) | Err(_) => return,
                Ok(n) => request.extend_from_slice(&chunk[..n]),
            }
            let head_len = match request
                .windows(HEAD_END.len())
                .position(|window| window == HEAD_END)
            {
                Some(at) => at + HEAD_END.len(),
                None => continue,
            };
            let head = String::from_utf8_lossy(&request[..head_len]);
            let body_len = head
                .lines()
                .filter_map(|line| line.split_once(':'))
                .find(|(name, _)| name.trim().eq_ignore_ascii_case("content-length"))
                .and_then(|(_, value)| value.trim().parse::<usize>().ok())
                .unwrap_or(0);
            if request.len() >= head_len.saturating_add(body_len) {
                return;
            }
        }
    }

    let listener = TcpListener::bind(("127.0.0.1", port)).expect("bind test HTTP socket");
    let reply = response.to_vec();
    std::thread::spawn(move || {
        if let Ok((mut stream, _)) = listener.accept() {
            drain_request(&mut stream);
            stream.write_all(&reply).unwrap();
            stream.flush().unwrap();
            if keep_open_ms > 0 {
                std::thread::sleep(Duration::from_millis(keep_open_ms));
            }
            let _ = stream.shutdown(Shutdown::Write);
        }
    })
}
/// Assembles a raw HTTP response: the status line, one `Name: value` line per header, a
/// blank line, then the body. Every line ends in CRLF.
///
/// # Panics
///
/// Panics if `body` is not valid UTF-8.
fn make_reply(status_line: &str, headers: &[(&str, String)], body: &[u8]) -> String {
    let header_block: String = headers
        .iter()
        .map(|(name, value)| format!("{name}: {value}\r\n"))
        .collect();
    let payload = std::str::from_utf8(body).unwrap();
    format!("{status_line}\r\n{header_block}\r\n{payload}")
}
/// Issues a GET against a one-shot stub server for each canned reply (fixed-length body,
/// chunked body, HTTP 500) and checks the decoded JSON or the reported status code.
#[test]
#[serial]
fn http_response_scenarios() {
    struct Case {
        name: &'static str,
        reply: String,
        keep_open: u64,
        expect: std::result::Result<serde_json::Value, u16>,
    }

    let fixed_body = br#"{"ok":true}"#;
    let fixed_ok = make_reply(
        "HTTP/1.1 200 OK",
        &[
            ("Content-Length", fixed_body.len().to_string()),
            ("Content-Type", "application/json".into()),
            ("Connection", "close".into()),
        ],
        fixed_body,
    );

    let chunk_body = br#"{"hello":"world"}"#;
    // One chunk carrying the whole body, followed by the terminating zero-length chunk.
    let chunked_ok = format!(
        "HTTP/1.1 200 OK\r\n\
         Transfer-Encoding: chunked\r\n\
         Connection: close\r\n\r\n\
         {:x}\r\n{}\r\n0\r\n\r\n",
        chunk_body.len(),
        std::str::from_utf8(chunk_body).unwrap(),
    );

    let err_500 = make_reply(
        "HTTP/1.1 500 Internal Server Error",
        &[
            ("Content-Length", "0".into()),
            ("Connection", "close".into()),
        ],
        b"",
    );

    let cases = vec![
        Case {
            name: "fixed_len_ok",
            reply: fixed_ok,
            keep_open: 0,
            expect: Ok(json!({"ok": true})),
        },
        Case {
            name: "chunked_ok",
            reply: chunked_ok,
            keep_open: 0,
            expect: Ok(json!({"hello": "world"})),
        },
        Case {
            name: "err_500",
            reply: err_500,
            keep_open: 0,
            expect: Err(500),
        },
    ];

    for Case {
        name,
        reply,
        keep_open,
        expect,
    } in cases
    {
        let client =
            HttpClient::new("dummy_http_client", None, None).expect("create HttpClient");
        let port = client.port;
        // Each case owns its reply, so it is moved into the leaked buffer without cloning.
        // The leak yields the `'static` slice `spawn_http_server` takes.
        let reply_bytes: &'static [u8] =
            Box::<[u8]>::leak(reply.into_bytes().into_boxed_slice());
        let _srv = spawn_http_server(port, reply_bytes, keep_open);

        match expect {
            Ok(wanted) => {
                let got: serde_json::Value = client.get("/").expect("request should succeed");
                assert_eq!(got, wanted, "case `{}` JSON mismatch", name);
            }
            Err(code) => {
                let err = client
                    .get::<serde_json::Value>("/")
                    .expect_err("expected error");
                match err {
                    ClientError::Remote { code: c2, .. } => {
                        assert_eq!(c2, code, "case `{}` wrong status code", name)
                    }
                    other => panic!("case `{}` expected Remote error, got {other:?}", name),
                }
            }
        }
    }
}
/// POSTs a JSON payload to a one-shot stub server and checks that the canned JSON reply is
/// decoded.
#[test]
#[serial]
fn http_post_sends_body_and_parses_response() {
    let ack = br#"{"ack":true}"#;
    let headers: [(&str, String); 3] = [
        ("Content-Length", ack.len().to_string()),
        ("Content-Type", "application/json".into()),
        ("Connection", "close".into()),
    ];
    let raw_reply = make_reply("HTTP/1.1 200 OK", &headers, ack);

    let client = HttpClient::new("dummy_http_client", None, None).expect("create HttpClient");
    // Leaked to obtain the `'static` slice `spawn_http_server` takes.
    let leaked: &'static [u8] = Box::<[u8]>::leak(raw_reply.into_bytes().into_boxed_slice());
    let _srv = spawn_http_server(client.port, leaked, 0);

    let payload = json!({"msg":"hi"});
    let outcome: serde_json::Value = client
        .post("/", &payload)
        .expect("POST request should succeed");
    assert_eq!(outcome, json!({"ack": true}));
}
}