#![allow(
dead_code,
clippy::doc_markdown,
clippy::struct_excessive_bools,
clippy::fn_params_excessive_bools
)]
use std::io::{BufRead, BufReader, ErrorKind, Read};
use std::net::TcpStream;
use std::path::Path;
use std::process::{Child, ChildStderr, Command, Stdio};
use std::sync::mpsc::{self, RecvTimeoutError};
use std::thread;
use std::time::{Duration, Instant};
/// Default time allowed for the server to announce its listeners, also used
/// as the retry budget of [`connect_to`].
pub const STARTUP_TIMEOUT: Duration = Duration::from_millis(10_000);
/// Listener addresses scraped from a freshly spawned server's stderr.
#[derive(Clone, Debug)]
pub struct ServerAddrs {
    /// Address from the server's `listening on ` line (always present).
    pub native: String,
    /// Address from the `http listening on ` line; guaranteed `Some` when
    /// HTTP was requested from the builder.
    pub http: Option<String>,
    /// Address from the `pg-wire listening on ` line; guaranteed `Some` when
    /// pg-wire was requested from the builder.
    pub pgwire: Option<String>,
    /// Address from the `replication listening on ` line; guaranteed `Some`
    /// when replication was requested from the builder.
    pub repl: Option<String>,
}
/// Configures and launches the `spg-server` binary for a test.
pub struct ServerBuilder {
    // --- process invocation ---
    /// Extra CLI arguments appended after the `127.0.0.1:0` listen address.
    extra_args: Vec<String>,
    /// Environment variables stripped from the inherited environment.
    env_remove: Vec<String>,
    /// Environment variables set on the child; applied after `env_remove`.
    extra_env: Vec<(String, String)>,
    // --- startup handshake ---
    /// Whether `spawn` must wait for an HTTP listener announcement.
    want_http: bool,
    /// Whether `spawn` must wait for a pg-wire listener announcement.
    want_pgwire: bool,
    /// Whether `spawn` must wait for a replication listener announcement.
    want_repl: bool,
    /// How long `spawn` waits for the requested announcements.
    startup_timeout: Duration,
    /// Whether the server's stderr is echoed to the test's stderr.
    inherit_stderr_echo: bool,
}
impl Default for ServerBuilder {
    /// A native-only server:
    /// - no optional listeners and no extra args or env;
    /// - stderr is not echoed;
    /// - the startup timeout is [`STARTUP_TIMEOUT`];
    /// - the variables from `alloc_default_env_remove` are stripped.
    fn default() -> Self {
        let env_remove = alloc_default_env_remove();
        Self {
            want_http: false,
            want_pgwire: false,
            want_repl: false,
            inherit_stderr_echo: false,
            extra_args: vec![],
            extra_env: vec![],
            env_remove,
            startup_timeout: STARTUP_TIMEOUT,
        }
    }
}
/// Environment variables stripped from the server's inherited environment by
/// default, so values exported in the developer's shell do not reach tests.
fn alloc_default_env_remove() -> Vec<String> {
    const NAMES: [&str; 3] = ["SPG_PASSWORD", "SPG_ADMIN_PASSWORD", "SPG_PG_ADDR"];
    NAMES.iter().map(|name| (*name).to_owned()).collect()
}
impl ServerBuilder {
    /// Equivalent to [`ServerBuilder::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends one extra command-line argument for the server.
    #[must_use]
    pub fn arg(mut self, a: impl Into<String>) -> Self {
        self.extra_args.push(a.into());
        self
    }

    /// Appends a filesystem path as an extra argument.
    ///
    /// Non-UTF-8 components are replaced lossily (`Path::to_string_lossy`).
    #[must_use]
    pub fn arg_path(mut self, p: &Path) -> Self {
        self.extra_args.push(p.to_string_lossy().into_owned());
        self
    }

    /// Sets an environment variable on the server process.
    ///
    /// Variables set here are applied after all removals, so they win over
    /// the default strip list.
    #[must_use]
    pub fn env(mut self, k: impl Into<String>, v: impl Into<String>) -> Self {
        self.extra_env.push((k.into(), v.into()));
        self
    }

    /// Stops stripping `k` from the inherited environment, so the test
    /// process's own value (if any) reaches the server.
    #[must_use]
    pub fn keep_env(mut self, k: &str) -> Self {
        self.env_remove.retain(|x| x != k);
        self
    }

    /// Sets `SPG_HTTP_ADDR=127.0.0.1:0` and makes [`spawn`](Self::spawn) wait
    /// for the `http listening on` announcement.
    #[must_use]
    pub fn with_http(mut self) -> Self {
        self.want_http = true;
        self.extra_env
            .push(("SPG_HTTP_ADDR".into(), "127.0.0.1:0".into()));
        self
    }

    /// Sets `SPG_PG_ADDR=127.0.0.1:0` and makes [`spawn`](Self::spawn) wait
    /// for the `pg-wire listening on` announcement.
    #[must_use]
    pub fn with_pgwire(mut self) -> Self {
        self.want_pgwire = true;
        // `SPG_PG_ADDR` is in the default strip list; take it off so the
        // intent is explicit (the value below would win regardless).
        self.env_remove.retain(|k| k != "SPG_PG_ADDR");
        self.extra_env
            .push(("SPG_PG_ADDR".into(), "127.0.0.1:0".into()));
        self
    }

    /// Sets `SPG_REPL_ADDR=127.0.0.1:0` and makes [`spawn`](Self::spawn) wait
    /// for the `replication listening on` announcement.
    #[must_use]
    pub fn with_repl(mut self) -> Self {
        self.want_repl = true;
        self.extra_env
            .push(("SPG_REPL_ADDR".into(), "127.0.0.1:0".into()));
        self
    }

    /// Sets `SPG_WAL_LEVEL=logical` on the server.
    #[must_use]
    pub fn with_logical_wal(mut self) -> Self {
        self.extra_env
            .push(("SPG_WAL_LEVEL".into(), "logical".into()));
        self
    }

    /// Echoes the server's stderr to the test's stderr when `on` is true.
    #[must_use]
    pub fn echo_stderr(mut self, on: bool) -> Self {
        self.inherit_stderr_echo = on;
        self
    }

    /// Overrides how long [`spawn`](Self::spawn) waits for listener
    /// announcements (default [`STARTUP_TIMEOUT`]).
    #[must_use]
    pub fn startup_timeout(mut self, d: Duration) -> Self {
        self.startup_timeout = d;
        self
    }

    /// Spawns the server and blocks until it has announced every requested
    /// listener on stderr.
    ///
    /// The returned [`Child`] is not killed on drop; wrap it in
    /// [`ChildGuard`] to clean it up.
    ///
    /// # Panics
    /// - if the binary cannot be spawned;
    /// - if the server exits, or its stderr cannot be read, before it
    ///   announces the requested listeners;
    /// - if the listeners are not announced within the startup timeout.
    #[must_use = "dropping the Child leaves the server running; wrap it in ChildGuard"]
    pub fn spawn(self) -> (Child, ServerAddrs) {
        let mut child = self
            .command(Stdio::piped())
            .spawn()
            .expect("spawn spg-server");
        let stderr = child.stderr.take().expect("piped stderr");
        let addrs = read_listener_addrs(
            &mut child,
            stderr,
            self.startup_timeout,
            self.want_http,
            self.want_pgwire,
            self.want_repl,
            self.inherit_stderr_echo,
        );
        (child, addrs)
    }

    /// Spawns the server without waiting for any announcement, for tests that
    /// expect startup to fail.
    ///
    /// Stdout and stderr are discarded.
    ///
    /// # Panics
    /// Panics if the binary cannot be spawned.
    #[must_use = "dropping the Child leaves the server running; wait on it or wrap it in ChildGuard"]
    pub fn spawn_expecting_startup_failure(self) -> Child {
        self.command(Stdio::null())
            .spawn()
            .expect("spawn spg-server")
    }

    /// Builds the server command shared by both spawn methods.
    ///
    /// The command listens on `127.0.0.1:0`, gets the extra args, and
    /// discards stdout. Removals are applied before additions.
    fn command(&self, stderr: Stdio) -> Command {
        let mut cmd = Command::new(env!("CARGO_BIN_EXE_spg-server"));
        cmd.arg("127.0.0.1:0")
            .args(&self.extra_args)
            .stdout(Stdio::null())
            .stderr(stderr);
        for k in &self.env_remove {
            cmd.env_remove(k);
        }
        cmd.envs(self.extra_env.iter().map(|(k, v)| (k, v)));
        cmd
    }
}
fn read_listener_addrs(
child: &mut Child,
stderr: ChildStderr,
deadline: Duration,
want_http: bool,
want_pgwire: bool,
want_repl: bool,
inherit_echo: bool,
) -> ServerAddrs {
let mut reader = BufReader::new(stderr);
let until = Instant::now() + deadline;
let mut native: Option<String> = None;
let mut http: Option<String> = None;
let mut pgwire: Option<String> = None;
let mut repl: Option<String> = None;
let mut line = String::new();
while Instant::now() < until {
if native.is_some()
&& (!want_http || http.is_some())
&& (!want_pgwire || pgwire.is_some())
&& (!want_repl || repl.is_some())
{
break;
}
line.clear();
match reader.read_line(&mut line) {
Ok(0) => {
if let Ok(Some(status)) = child.try_wait() {
panic!("server exited before publishing addrs: {status:?}");
}
thread::sleep(Duration::from_millis(20));
}
Ok(_) => {
if inherit_echo {
eprint!("{line}");
}
if let Some(a) = extract("http listening on ", &line) {
http = Some(a);
} else if let Some(a) = extract("pg-wire listening on ", &line) {
pgwire = Some(a);
} else if let Some(a) = extract("replication listening on ", &line) {
repl = Some(a);
} else if let Some(a) = extract("listening on ", &line) {
native = Some(a);
}
}
Err(e) => panic!("read stderr: {e}"),
}
}
let Some(n) = native else {
let _ = child.kill();
panic!("server didn't publish native listen addr within {deadline:?}");
};
if want_http && http.is_none() {
let _ = child.kill();
panic!("server didn't publish http addr within {deadline:?}");
}
if want_pgwire && pgwire.is_none() {
let _ = child.kill();
panic!("server didn't publish pg-wire addr within {deadline:?}");
}
if want_repl && repl.is_none() {
let _ = child.kill();
panic!("server didn't publish replication addr within {deadline:?}");
}
thread::spawn(move || {
if inherit_echo {
let mut buf = String::new();
while let Ok(n) = reader.read_line(&mut buf) {
if n == 0 {
break;
}
eprint!("{buf}");
buf.clear();
}
} else {
let mut sink = String::new();
let _ = reader.read_to_string(&mut sink);
}
});
ServerAddrs {
native: n,
http,
pgwire,
repl,
}
}
/// Returns the token that follows the first occurrence of `marker` in `line`.
///
/// The token ends at the first space, `\n` or `\r`, or at the end of the line.
/// Returns `None` when `marker` does not occur in `line`.
fn extract(marker: &str, line: &str) -> Option<String> {
    let (_, rest) = line.split_once(marker)?;
    let token = rest.split([' ', '\n', '\r']).next().unwrap_or_default();
    Some(token.to_owned())
}
/// Owns a child process and kills and reaps it when dropped, so a panicking
/// test does not leave the process running.
pub struct ChildGuard(pub Child);

impl Drop for ChildGuard {
    fn drop(&mut self) {
        let Self(child) = self;
        // Best-effort cleanup: errors from either call are deliberately ignored.
        let _ = child.kill();
        let _ = child.wait();
    }
}
/// Resident set size of process `pid` in KiB, as reported by `ps -o rss= -p <pid>`.
///
/// Returns 0 in any of these cases:
/// - `ps` cannot be run;
/// - `ps` exits unsuccessfully;
/// - the trimmed output is not a plain unsigned integer.
pub fn rss_kib_of(pid: u32) -> u64 {
    let pid = pid.to_string();
    Command::new("ps")
        .args(["-o", "rss=", "-p", pid.as_str()])
        .output()
        .ok()
        .filter(|out| out.status.success())
        .and_then(|out| {
            String::from_utf8_lossy(&out.stdout)
                .trim()
                .parse::<u64>()
                .ok()
        })
        .unwrap_or(0)
}
/// Connects to `addr`, retrying every 10 ms until [`STARTUP_TIMEOUT`] elapses.
///
/// # Panics
/// Panics with the last connect error if no attempt succeeds in time.
pub fn connect_to(addr: &str) -> TcpStream {
    let give_up_at = Instant::now() + STARTUP_TIMEOUT;
    loop {
        let err = match TcpStream::connect(addr) {
            Ok(stream) => return stream,
            Err(err) => err,
        };
        assert!(Instant::now() < give_up_at, "connect {addr}: {err}");
        thread::sleep(Duration::from_millis(10));
    }
}