use std::process::{Command, Stdio};
use once_cell::sync::Lazy;
use std::sync::{Arc, Mutex};
use eyre::{eyre, WrapErr};
use owo_colors::OwoColorize;
use pgx::*;
use pgx_utils::pg_config::{PgConfig, Pgx};
use pgx_utils::{createdb, get_named_capture, get_target_dir};
use postgres::error::DbError;
use postgres::Client;
use std::collections::HashMap;
use std::fmt::Write as _;
use std::io::{BufRead, BufReader, Write};
use std::path::PathBuf;
/// Log lines captured from Postgres, shared across threads and keyed by session id.
type LogLines = Arc<Mutex<HashMap<String, Vec<String>>>>;
/// Mutable state shared by all tests in this process; it lives inside `TEST_MUTEX`.
struct SetupState {
/// `true` once `initialize_test_framework` has completed its one-time setup.
installed: bool,
/// Log lines collected by the Postgres monitoring thread, keyed by session id.
loglines: LogLines,
/// Session id returned by `start_pg`; holds the placeholder `"NONE"` until then.
system_session_id: String,
}
static TEST_MUTEX: Lazy<Mutex<SetupState>> = Lazy::new(|| {
Mutex::new(SetupState {
installed: false,
loglines: Arc::new(Mutex::new(HashMap::new())),
system_session_id: "NONE".to_string(),
})
});
/// Closures queued by `add_shutdown_hook`; `register_shutdown_hook` installs the callback that
/// runs them. The list starts out empty.
static SHUTDOWN_HOOKS: Lazy<Mutex<Vec<Box<dyn Fn() + Send>>>> = Lazy::new(Mutex::default);
/// Installs, through the `shutdown_hooks` crate, a C-ABI callback that invokes every closure
/// currently stored in `SHUTDOWN_HOOKS`, in insertion order.
fn register_shutdown_hook() {
    extern "C" fn run_shutdown_hooks() {
        let hooks = SHUTDOWN_HOOKS.lock().unwrap();
        hooks.iter().for_each(|hook| hook());
    }

    shutdown_hooks::add_shutdown_hook(run_shutdown_hooks);
}
/// Queues `func` so that it is called by the callback installed in `register_shutdown_hook`.
///
/// # Panics
///
/// Panics if the `SHUTDOWN_HOOKS` mutex is poisoned.
pub fn add_shutdown_hook<F: Fn()>(func: F)
where
    F: Send + 'static,
{
    let hook: Box<dyn Fn() + Send> = Box::new(func);
    let mut registered = SHUTDOWN_HOOKS.lock().unwrap();
    registered.push(hook);
}
pub fn run_test(
sql_funcname: &str,
expected_error: Option<&str>,
postgresql_conf: Vec<&'static str>,
) -> eyre::Result<()> {
let (loglines, system_session_id) = initialize_test_framework(postgresql_conf)?;
let (mut client, session_id) = client();
let schema = "tests"; let result = match client.transaction() {
Ok(mut tx) => {
let result = tx.simple_query(&format!("SELECT \"{schema}\".\"{sql_funcname}\"();"));
if result.is_ok() {
tx.rollback().expect("test rollback didn't work");
}
result
}
Err(e) => panic!("attempt to run test tx failed:\n{e}"),
};
if let Err(e) = result {
let error_as_string = format!("error in test tx: {e}");
let cause = e.into_source();
if let Some(e) = cause {
if let Some(dberror) = e.downcast_ref::<DbError>() {
let received_error_message: &str = dberror.message();
if let Some(expected_error_message) = expected_error {
assert_eq!(received_error_message, expected_error_message);
Ok(())
} else {
std::thread::sleep(std::time::Duration::from_millis(1000));
let mut pg_location = String::from("Postgres location: ");
pg_location.push_str(match dberror.file() {
Some(file) => file,
None => "<unknown>",
});
if let Some(ln) = dberror.line() {
let _ = write!(pg_location, ":{ln}");
};
let mut rust_location = String::from("Rust location: ");
rust_location.push_str(match dberror.where_() {
Some(place) => place,
None => "<unknown>",
});
panic!(
"\n{sys}...\n{sess}\n{e}\n{pg}\n{rs}\n\n",
sys = format_loglines(&system_session_id, &loglines),
sess = format_loglines(&session_id, &loglines),
e = received_error_message.bold().red(),
pg = pg_location.dimmed().white(),
rs = rust_location.yellow()
);
}
} else {
panic!("Failed downcast to DbError:\n{e}")
}
} else {
panic!(
"Error without deeper source cause:\n{e}\n",
e = error_as_string.bold().red()
)
}
} else if let Some(message) = expected_error {
return Err(eyre!("Expected error: {message}"));
} else {
Ok(())
}
}
/// Concatenates the log lines recorded for `session_id`, terminating each with `'\n'`.
///
/// Returns an empty string when nothing has been recorded for that session. The lookup is
/// read-only: no entry is created in `loglines` for an unknown session id.
///
/// # Panics
///
/// Panics if the `loglines` mutex is poisoned.
fn format_loglines(session_id: &str, loglines: &LogLines) -> String {
    let sessions = loglines.lock().unwrap();
    let mut result = String::new();
    if let Some(lines) = sessions.get(session_id) {
        for line in lines {
            result.push_str(line);
            result.push('\n');
        }
    }
    result
}
/// Performs the one-time test framework setup (if it has not happened yet) and returns the
/// shared log-line store together with the system session id.
///
/// While holding `TEST_MUTEX`, the first successful call registers the shutdown callback,
/// installs the extension, prepares the data directory with `postgresql_conf`, starts Postgres,
/// recreates the test database and creates the extension in it. Later calls only return clones
/// of the stored values.
///
/// # Errors
///
/// Returns an error if installing the extension, `initdb`, or starting Postgres fails.
///
/// # Panics
///
/// Panics if `TEST_MUTEX` is poisoned or if the test database cannot be created; `dropdb` and
/// `create_extension` may panic as well.
fn initialize_test_framework(
    postgresql_conf: Vec<&'static str>,
) -> eyre::Result<(LogLines, String)> {
    let mut state = TEST_MUTEX.lock().unwrap_or_else(|_| {
        // Exiting the process here would end the whole test run without any explanation;
        // a panic fails the current test and reports why.
        panic!(
            "could not obtain the test framework mutex: a previous test panicked while \
             initializing the test framework"
        );
    });

    if !state.installed {
        register_shutdown_hook();
        install_extension()?;
        initdb(postgresql_conf)?;
        let system_session_id = start_pg(state.loglines.clone())?;
        let pg_config = get_pg_config();
        dropdb();
        createdb(&pg_config, get_pg_dbname(), true, false).expect("failed to create test database");
        create_extension();
        state.installed = true;
        state.system_session_id = system_session_id;
    }

    Ok((state.loglines.clone(), state.system_session_id.clone()))
}
/// Loads the pgx configuration and returns a copy of the entry for the Postgres major version
/// reported by `pg_sys::get_pg_major_version_num()`.
///
/// # Panics
///
/// Panics if the pgx configuration cannot be loaded or has no entry for that version.
fn get_pg_config() -> PgConfig {
    let pgx = Pgx::from_config().expect("Unable to load pgx config");
    let label = format!("pg{}", pg_sys::get_pg_major_version_num());
    let config = pgx.get(&label).expect("not a valid postgres version");
    config.clone()
}
/// Opens a new connection to the test database and returns it together with the session id of
/// its backend.
///
/// After connecting, the session's `log_min_messages`, `log_min_duration_statement` and
/// `log_statement` settings are adjusted.
///
/// # Panics
///
/// Panics if the test port is unknown, the connection fails, the session id cannot be
/// determined, or any of the `SET` statements fails.
pub fn client() -> (postgres::Client, String) {
    /// Asks the server for this backend's session id (hex start time, a dot, hex pid).
    fn determine_session_id(client: &mut Client) -> String {
        let rows = client.query("SELECT to_hex(trunc(EXTRACT(EPOCH FROM backend_start))::integer) || '.' || to_hex(pid) AS sid FROM pg_stat_activity WHERE pid = pg_backend_pid();", &[]).expect("failed to determine session id");
        rows.first()
            .expect("No session id returned from query")
            .get::<&str, &str>("sid")
            .to_string()
    }

    let pg_config = get_pg_config();
    let port = pg_config
        .test_port()
        .expect("unable to determine test port");

    let mut settings = postgres::Config::new();
    settings
        .host(pg_config.host())
        .port(port)
        .user(&get_pg_user())
        .dbname(get_pg_dbname());
    let mut conn = settings.connect(postgres::NoTls).unwrap();

    let session_id = determine_session_id(&mut conn);

    let session_settings = [
        (
            "SET log_min_messages TO 'INFO';",
            "FAILED: SET log_min_messages TO 'INFO'",
        ),
        (
            "SET log_min_duration_statement TO 1000;",
            "FAILED: SET log_min_duration_statement TO 1000",
        ),
        (
            "SET log_statement TO 'all';",
            "FAILED: SET log_statement TO 'all'",
        ),
    ];
    for (statement, failure) in session_settings.iter() {
        conn.simple_query(statement).expect(failure);
    }

    (conn, session_id)
}
fn install_extension() -> eyre::Result<()> {
eprintln!("installing extension");
let is_release = std::env::var("PGX_BUILD_PROFILE").unwrap_or("debug".into()) == "release";
let no_schema = std::env::var("PGX_NO_SCHEMA").unwrap_or("false".into()) == "true";
let mut features = std::env::var("PGX_FEATURES").unwrap_or("".to_string());
if !features.contains("pg_test") {
features += " pg_test";
}
let no_default_features =
std::env::var("PGX_NO_DEFAULT_FEATURES").unwrap_or("false".to_string()) == "true";
let all_features = std::env::var("PGX_ALL_FEATURES").unwrap_or("false".to_string()) == "true";
let pg_version = format!("pg{}", pg_sys::get_pg_major_version_string());
let pgx = Pgx::from_config()?;
let pg_config = pgx.get(&pg_version)?;
let mut command = Command::new("cargo");
command
.arg("pgx")
.arg("install")
.arg("--test")
.arg("--pg-config")
.arg(pg_config.path().ok_or(eyre!("No pg_config found"))?)
.stdout(Stdio::inherit())
.stderr(Stdio::inherit())
.env("CARGO_TARGET_DIR", get_target_dir()?);
if let Ok(manifest_path) = std::env::var("PGX_MANIFEST_PATH") {
command.arg("--manifest-path");
command.arg(manifest_path);
}
if let Ok(rust_log) = std::env::var("RUST_LOG") {
command.env("RUST_LOG", rust_log);
}
if !features.trim().is_empty() {
command.arg("--features");
command.arg(features);
}
if no_default_features {
command.arg("--no-default-features");
}
if all_features {
command.arg("--all-features");
}
if is_release {
command.arg("--release");
}
if no_schema {
command.arg("--no-schema");
}
let mut child = command.spawn().unwrap();
let status = child.wait().unwrap();
if !status.success() {
return Err(eyre!("failed to install extension"));
}
Ok(())
}
fn initdb(postgresql_conf: Vec<&'static str>) -> eyre::Result<()> {
let pg_config = get_pg_config();
let pgdata = get_pgdata_path()?;
if !pgdata.is_dir() {
let status = Command::new(
pg_config
.initdb_path()
.wrap_err("unable to determine initdb path")?,
)
.arg("-D")
.arg(pgdata.to_str().unwrap())
.stdout(Stdio::inherit())
.stderr(Stdio::inherit())
.status()
.unwrap();
if !status.success() {
return Err(eyre!("initdb failed"));
}
}
modify_postgresql_conf(pgdata, postgresql_conf)
}
/// Replaces the contents of `<pgdata>/postgresql.auto.conf`.
///
/// The file is truncated and then receives, in order: a fixed `log_line_prefix` setting, each
/// entry of `postgresql_conf` on its own line, and a `unix_socket_directories` setting that
/// points at the directory returned by `Pgx::home()`. The file must already exist; it is not
/// created here.
///
/// # Errors
///
/// Returns an error if the file cannot be opened or written.
///
/// # Panics
///
/// Panics if `Pgx::home()` fails.
fn modify_postgresql_conf(pgdata: PathBuf, postgresql_conf: Vec<&'static str>) -> eyre::Result<()> {
    let mut postgresql_conf_file = std::fs::OpenOptions::new()
        .write(true)
        .truncate(true)
        // Joined as a path rather than formatted into a string, so nothing is lost for
        // directory names that are not valid UTF-8.
        .open(pgdata.join("postgresql.auto.conf"))
        .wrap_err("couldn't open postgresql.auto.conf")?;

    postgresql_conf_file
        .write_all(b"log_line_prefix='[%m] [%p] [%c]: '\n")
        .wrap_err("couldn't append log_line_prefix")?;

    for setting in postgresql_conf {
        writeln!(postgresql_conf_file, "{setting}")
            .wrap_err("couldn't append custom setting to postgresql.auto.conf")?;
    }

    write!(
        postgresql_conf_file,
        "unix_socket_directories = '{}'",
        Pgx::home().unwrap().display()
    )
    .wrap_err("couldn't append `unix_socket_directories` setting to postgresql.auto.conf")?;

    Ok(())
}
/// Starts the postmaster for the test data directory and returns the session id reported by
/// `monitor_pg` once Postgres accepts connections.
///
/// The postmaster's stderr is piped so that `monitor_pg` can collect its output into
/// `loglines`. A shutdown hook is queued that prints a notice and sends `SIGTERM` to the
/// postmaster.
///
/// # Errors
///
/// Returns an error if the postmaster path or the data directory cannot be determined.
///
/// # Panics
///
/// Panics if the test port is unknown; `monitor_pg` panics if Postgres fails to start.
fn start_pg(loglines: LogLines) -> eyre::Result<String> {
    let pg_config = get_pg_config();
    let mut command = Command::new(
        pg_config
            .postmaster_path()
            .wrap_err("unable to determine postmaster path")?,
    );
    command
        .arg("-D")
        // Passed as a path so that a non-UTF-8 directory name cannot cause a panic.
        .arg(get_pgdata_path()?)
        .arg("-h")
        .arg(pg_config.host())
        .arg("-p")
        .arg(
            pg_config
                .test_port()
                .expect("unable to determine test port")
                .to_string(),
        )
        .stdout(Stdio::inherit())
        .stderr(Stdio::piped());

    let command_str = format!("{command:?}");
    let (pgpid, session_id) = monitor_pg(command, command_str, loglines);

    // Built up front so the hook itself neither allocates nor formats.
    let message_string =
        std::ffi::CString::new("Stopping Postgres\n\n".bold().blue().to_string()).unwrap();
    add_shutdown_hook(move || {
        // SAFETY: both pointers passed to `printf` reference NUL-terminated buffers that are
        // alive for the duration of the call, and the `%s` format consumes exactly the one
        // argument supplied. `kill` takes plain integers.
        unsafe {
            // The message is passed as data, never as the format string.
            // NOTE(review): `printf` is presumably used because Rust's own stdout handling is
            // unreliable inside a shutdown callback -- confirm.
            libc::printf(
                b"%s\0".as_ptr() as *const libc::c_char,
                message_string.as_ptr(),
            );
            // `pgpid` comes from `Child::id()` (a `u32`); the cast assumes it fits in `pid_t`.
            libc::kill(pgpid as libc::pid_t, libc::SIGTERM);
        }
    });

    Ok(session_id)
}
/// Spawns `command` (the postmaster) on a background thread, collects its stderr into
/// `loglines`, and blocks until Postgres reports that it is ready to accept connections.
///
/// Every stderr line is stored under the session id extracted from its log prefix, or under
/// `"NONE"` when the line carries none. Lines are echoed to this process' stderr until startup
/// has completed; afterwards only lines containing `"TMSG: "` are echoed.
///
/// Returns the postmaster's process id and the session id found on the readiness line.
///
/// # Panics
///
/// Panics if the background thread ends before signalling readiness. The background thread
/// itself panics if the postmaster cannot be spawned or its stderr cannot be taken.
fn monitor_pg(mut command: Command, cmd_string: String, loglines: LogLines) -> (u32, String) {
    let (sender, receiver) = std::sync::mpsc::channel();

    std::thread::spawn(move || {
        let mut child = command.spawn().expect("postmaster didn't spawn");
        let pid = child.id();

        eprintln!(
            "{cmd}\npid={p}",
            cmd = cmd_string.bold().blue(),
            p = pid.to_string().yellow()
        );
        eprintln!("{}", pg_sys::get_pg_version_string().bold().purple());

        let reader = BufReader::new(
            child
                .stderr
                .take()
                .expect("couldn't take postmaster stderr"),
        );

        // Matches the `[%m] [%p] [%c]: ` prefix; the third bracket holds the session id.
        let regex = regex::Regex::new(r#"\[.*?\] \[.*?\] \[(?P<session_id>.*?)\]"#).unwrap();
        let mut is_started_yet = false;

        let mut lines = reader.lines();
        // Stops at end of stream or at the first line that cannot be read.
        while let Some(Ok(line)) = lines.next() {
            let session_id = get_named_capture(&regex, "session_id", &line)
                .unwrap_or_else(|| "NONE".to_string());

            if line.contains("database system is ready to accept connections") {
                // Only the first readiness line is reported: the receiving end is dropped once
                // `monitor_pg` returns, so sending again would fail and take down this thread
                // (and with it all further log collection).
                if !is_started_yet {
                    sender.send((pid, session_id.clone())).unwrap();
                }
                is_started_yet = true;
            }

            if !is_started_yet || line.contains("TMSG: ") {
                eprintln!("{}", line.cyan());
            }

            loglines
                .lock()
                .unwrap()
                .entry(session_id)
                .or_default()
                .push(line);
        }

        // stderr is closed; check on the child without blocking. The exit status is not used.
        if let Err(e) = child.try_wait() {
            panic!("was going to let Postgres finish, but errored this time:\n{e}");
        }
    });

    receiver.recv().expect("Postgres failed to start")
}
/// Drops the test database by running `dropdb --if-exists` against the test server.
///
/// The `PGDATABASE`, `PGHOST`, `PGPORT` and `PGUSER` variables are removed from the child's
/// environment. A failure whose stderr says the database does not exist is tolerated.
///
/// # Panics
///
/// Panics if the `dropdb` path or test port cannot be determined, if the command cannot be
/// run, or if it fails for any other reason (after printing its stdout and stderr).
fn dropdb() {
    let pg_config = get_pg_config();
    let dropdb_path = pg_config
        .dropdb_path()
        .expect("unable to determine dropdb path");

    let mut command = Command::new(dropdb_path);
    for variable in ["PGDATABASE", "PGHOST", "PGPORT", "PGUSER"].iter() {
        command.env_remove(variable);
    }
    command.arg("--if-exists");
    command.arg("-h").arg(pg_config.host());
    command.arg("-p").arg(
        pg_config
            .test_port()
            .expect("unable to determine test port")
            .to_string(),
    );
    command.arg(get_pg_dbname());

    let output = command.output().unwrap();
    if output.status.success() {
        return;
    }

    let stderr = String::from_utf8_lossy(&output.stderr);
    let missing_database = format!(
        "ERROR: database \"{}\" does not exist",
        get_pg_dbname()
    );
    if stderr.contains(&missing_database) {
        return;
    }

    let stdout = String::from_utf8_lossy(&output.stdout);
    eprintln!("unexpected error (stdout):\n{stdout}");
    eprintln!("unexpected error (stderr):\n{stderr}");
    panic!("failed to drop test database");
}
/// Issues `CREATE EXTENSION <name> CASCADE;` on a fresh connection to the test database, with
/// the name taken from `get_extension_name()`.
///
/// # Panics
///
/// Panics if the connection cannot be made or the statement fails.
fn create_extension() {
    let (mut conn, _session_id) = client();
    let statement = format!("CREATE EXTENSION {} CASCADE;", get_extension_name());
    conn.simple_query(&statement).unwrap();
}
/// Returns the value of the `CARGO_PKG_NAME` environment variable with every `-` replaced by
/// `_`.
///
/// # Panics
///
/// Panics if `CARGO_PKG_NAME` is unset or not valid UTF-8.
fn get_extension_name() -> String {
    match std::env::var("CARGO_PKG_NAME") {
        Ok(package) => package.replace('-', "_"),
        Err(_) => panic!("CARGO_PKG_NAME environment var is unset or invalid UTF-8"),
    }
}
/// Returns the test data directory: `pgx-test-data-<major version>` inside the directory
/// reported by `get_target_dir()`.
///
/// # Errors
///
/// Returns an error if the target directory cannot be determined.
fn get_pgdata_path() -> eyre::Result<PathBuf> {
    let target_dir = get_target_dir()?;
    let dirname = format!("pgx-test-data-{}", pg_sys::get_pg_major_version_num());
    Ok(target_dir.join(dirname))
}
/// Name of the database the tests run against.
fn get_pg_dbname() -> &'static str {
    const TEST_DBNAME: &str = "pgx_tests";
    TEST_DBNAME
}
/// Returns the value of the `USER` environment variable.
///
/// # Panics
///
/// Panics if `USER` is unset or not valid UTF-8.
fn get_pg_user() -> String {
    match std::env::var("USER") {
        Ok(user) => user,
        Err(_) => panic!("USER environment var is unset or invalid UTF-8"),
    }
}