1use std::collections::HashSet;
11use std::process::{Command, Stdio};
12
13use eyre::{eyre, WrapErr};
14use once_cell::sync::Lazy;
15use owo_colors::OwoColorize;
16use pgx::prelude::*;
17use pgx_pg_config::{createdb, get_c_locale_flags, get_target_dir, PgConfig, Pgx};
18use postgres::error::DbError;
19use std::collections::HashMap;
20use std::fmt::Write as _;
21use std::io::{BufRead, BufReader, Write};
22use std::path::PathBuf;
23use std::sync::{Arc, Mutex};
24use sysinfo::{Pid, ProcessExt, System, SystemExt};
25
26mod shutdown;
27pub use shutdown::add_shutdown_hook;
28
29type LogLines = Arc<Mutex<HashMap<String, Vec<String>>>>;
30
31struct SetupState {
32 installed: bool,
33 loglines: LogLines,
34 system_session_id: String,
35}
36
37static TEST_MUTEX: Lazy<Mutex<SetupState>> = Lazy::new(|| {
38 Mutex::new(SetupState {
39 installed: false,
40 loglines: Arc::new(Mutex::new(HashMap::new())),
41 system_session_id: "NONE".to_string(),
42 })
43});
44
45fn query_wrapper<F, T>(
50 query: Option<String>,
51 query_params: Option<&[&(dyn postgres::types::ToSql + Sync)]>,
52 mut f: F,
53) -> eyre::Result<T>
54where
55 T: IntoIterator,
56 F: FnMut(
57 Option<String>,
58 Option<&[&(dyn postgres::types::ToSql + Sync)]>,
59 ) -> Result<T, postgres::Error>,
60{
61 let result = f(query.clone(), query_params.clone());
62
63 match result {
64 Ok(result) => Ok(result),
65 Err(e) => {
66 let dberror = e.as_db_error().unwrap();
67 let query = query.unwrap();
68 let query_message = dberror.message();
69
70 let code = dberror.code().code();
71 let severity = dberror.severity();
72
73 let mut message = format!("{} SQLSTATE[{}]", severity, code).bold().red().to_string();
74
75 message.push_str(format!(": {}", query_message.bold().white()).as_str());
76 message.push_str(format!("\nquery: {}", query.bold().white()).as_str());
77 message.push_str(
78 format!(
79 "\nparams: {}",
80 match query_params {
81 Some(params) => format!("{:?}", params),
82 None => "None".to_string(),
83 }
84 )
85 .as_str(),
86 );
87
88 if let Ok(var) = std::env::var("RUST_BACKTRACE") {
89 if var.eq("1") {
90 let detail = dberror.detail().unwrap_or("None");
91 let hint = dberror.hint().unwrap_or("None");
92 let schema = dberror.hint().unwrap_or("None");
93 let table = dberror.table().unwrap_or("None");
94 let more_info = format!(
95 "\ndetail: {detail}\nhint: {hint}\nschema: {schema}\ntable: {table}"
96 );
97 message.push_str(more_info.as_str());
98 }
99 }
100
101 Err(eyre!(message))
102 }
103 }
104}
105
106pub fn run_test(
107 sql_funcname: &str,
108 expected_error: Option<&str>,
109 postgresql_conf: Vec<&'static str>,
110) -> eyre::Result<()> {
111 let (loglines, system_session_id) = initialize_test_framework(postgresql_conf)?;
112
113 let (mut client, session_id) = client()?;
114
115 let schema = "tests"; let result = match client.transaction() {
117 Ok(mut tx) => {
119 let result = tx.simple_query(&format!("SELECT \"{schema}\".\"{sql_funcname}\"();"));
120
121 if result.is_ok() {
122 tx.rollback().expect("test rollback didn't work");
124 }
125
126 result
127 }
128
129 Err(e) => panic!("attempt to run test tx failed:\n{e}"),
130 };
131
132 if let Err(e) = result {
133 let error_as_string = format!("error in test tx: {e}");
134
135 let cause = e.into_source();
136 if let Some(e) = cause {
137 if let Some(dberror) = e.downcast_ref::<DbError>() {
138 let received_error_message: &str = dberror.message();
140
141 if let Some(expected_error_message) = expected_error {
142 assert_eq!(received_error_message, expected_error_message);
144 Ok(())
145 } else {
146 std::thread::sleep(std::time::Duration::from_millis(1000));
149
150 let mut pg_location = String::from("Postgres location: ");
151 pg_location.push_str(match dberror.file() {
152 Some(file) => file,
153 None => "<unknown>",
154 });
155 if let Some(ln) = dberror.line() {
156 let _ = write!(pg_location, ":{ln}");
157 };
158
159 let mut rust_location = String::from("Rust location: ");
160 rust_location.push_str(match dberror.where_() {
161 Some(place) => place,
162 None => "<unknown>",
163 });
164 panic!(
166 "\n{sys}...\n{sess}\n{e}\n{pg}\n{rs}\n\n",
167 sys = format_loglines(&system_session_id, &loglines),
168 sess = format_loglines(&session_id, &loglines),
169 e = received_error_message.bold().red(),
170 pg = pg_location.dimmed().white(),
171 rs = rust_location.yellow()
172 );
173 }
174 } else {
175 panic!("Failed downcast to DbError:\n{e}")
176 }
177 } else {
178 panic!("Error without deeper source cause:\n{e}\n", e = error_as_string.bold().red())
179 }
180 } else if let Some(message) = expected_error {
181 return Err(eyre!("Expected error: {message}"));
183 } else {
184 Ok(())
185 }
186}
187
188fn format_loglines(session_id: &str, loglines: &LogLines) -> String {
189 let mut result = String::new();
190
191 for line in loglines.lock().unwrap().entry(session_id.to_string()).or_default().iter() {
192 result.push_str(line);
193 result.push('\n');
194 }
195
196 result
197}
198
199fn initialize_test_framework(
200 postgresql_conf: Vec<&'static str>,
201) -> eyre::Result<(LogLines, String)> {
202 let mut state = TEST_MUTEX.lock().unwrap_or_else(|_| {
203 panic!(
207 "Could not obtain test mutex. A previous test may have hard-aborted while holding it."
208 );
209 });
210
211 if !state.installed {
212 shutdown::register_shutdown_hook();
213 install_extension()?;
214 initdb(postgresql_conf)?;
215
216 let system_session_id = start_pg(state.loglines.clone())?;
217 let pg_config = get_pg_config()?;
218 dropdb()?;
219 createdb(&pg_config, get_pg_dbname(), true, false)?;
220 create_extension()?;
221 state.installed = true;
222 state.system_session_id = system_session_id;
223 }
224
225 Ok((state.loglines.clone(), state.system_session_id.clone()))
226}
227
228fn get_pg_config() -> eyre::Result<PgConfig> {
229 let pgx = Pgx::from_config().wrap_err("Unable to get PGX from config")?;
230
231 let pg_version = pg_sys::get_pg_major_version_num();
232
233 let pg_config = pgx
234 .get(&format!("pg{}", pg_version))
235 .wrap_err_with(|| {
236 format!("Error getting pg_config: {} is not a valid postgres version", pg_version)
237 })
238 .unwrap()
239 .clone();
240
241 Ok(pg_config)
242}
243
244pub fn client() -> eyre::Result<(postgres::Client, String)> {
245 let pg_config = get_pg_config()?;
246 let mut client = postgres::Config::new()
247 .host(pg_config.host())
248 .port(pg_config.test_port().expect("unable to determine test port"))
249 .user(&get_pg_user())
250 .dbname(&get_pg_dbname())
251 .connect(postgres::NoTls)
252 .unwrap();
253
254 let sid_query_result = query_wrapper(
255 Some("SELECT to_hex(trunc(EXTRACT(EPOCH FROM backend_start))::integer) || '.' || to_hex(pid) AS sid FROM pg_stat_activity WHERE pid = pg_backend_pid();".to_string()),
256 Some(&[]),
257 |query, query_params| client.query(&query.unwrap(), query_params.unwrap()),
258 )
259 .wrap_err("There was an issue attempting to get the session ID from Postgres")?;
260
261 let session_id = match sid_query_result.get(0) {
262 Some(row) => row.get::<&str, &str>("sid").to_string(),
263 None => Err(eyre!("Failed to obtain a client Session ID from Postgres"))?,
264 };
265
266 query_wrapper(Some("SET log_min_messages TO 'INFO';".to_string()), None, |query, _| {
267 client.simple_query(query.unwrap().as_str())
268 })
269 .wrap_err("Postgres Client setup failed to SET log_min_messages TO 'INFO'")?;
270
271 query_wrapper(Some("SET log_min_duration_statement TO 1000;".to_string()), None, |query, _| {
272 client.simple_query(query.unwrap().as_str())
273 })
274 .wrap_err("Postgres Client setup failed to SET log_min_duration_statement TO 1000;")?;
275
276 query_wrapper(Some("SET log_statement TO 'all';".to_string()), None, |query, _| {
277 client.simple_query(query.unwrap().as_str())
278 })
279 .wrap_err("Postgres Client setup failed to SET log_statement TO 'all';")?;
280
281 Ok((client, session_id))
282}
283
284fn install_extension() -> eyre::Result<()> {
285 eprintln!("installing extension");
286 let profile = std::env::var("PGX_BUILD_PROFILE").unwrap_or("debug".into());
287 let no_schema = std::env::var("PGX_NO_SCHEMA").unwrap_or("false".into()) == "true";
288 let mut features = std::env::var("PGX_FEATURES")
289 .unwrap_or("".to_string())
290 .split_ascii_whitespace()
291 .map(|s| s.to_string())
292 .collect::<HashSet<_>>();
293 features.insert("pg_test".into());
294
295 let no_default_features =
296 std::env::var("PGX_NO_DEFAULT_FEATURES").unwrap_or("false".to_string()) == "true";
297 let all_features = std::env::var("PGX_ALL_FEATURES").unwrap_or("false".to_string()) == "true";
298
299 let pg_version = format!("pg{}", pg_sys::get_pg_major_version_string());
300 let pgx = Pgx::from_config()?;
301 let pg_config = pgx.get(&pg_version)?;
302 let cargo_test_args = get_cargo_test_features()?;
303 println!("detected cargo args: {:?}", cargo_test_args);
304
305 features.extend(cargo_test_args.features.iter().cloned());
306
307 let mut command = cargo_pgx();
308 command
309 .arg("install")
310 .arg("--test")
311 .arg("--pg-config")
312 .arg(pg_config.path().ok_or(eyre!("No pg_config found"))?)
313 .stdout(Stdio::inherit())
314 .stderr(Stdio::piped())
315 .env("CARGO_TARGET_DIR", get_target_dir()?);
316
317 if let Ok(manifest_path) = std::env::var("PGX_MANIFEST_PATH") {
318 command.arg("--manifest-path");
319 command.arg(manifest_path);
320 }
321
322 if let Ok(rust_log) = std::env::var("RUST_LOG") {
323 command.env("RUST_LOG", rust_log);
324 }
325
326 if !features.is_empty() {
327 command.arg("--features");
328 command.arg(features.into_iter().collect::<Vec<_>>().join(" "));
329 }
330
331 if no_default_features || cargo_test_args.no_default_features {
332 command.arg("--no-default-features");
333 }
334
335 if all_features || cargo_test_args.all_features {
336 command.arg("--all-features");
337 }
338
339 match profile.trim() {
340 "debug" | "dev" | "" => {}
343 "release" => {
344 command.arg("--release");
345 }
346 profile => {
347 command.args(["--profile", profile]);
348 }
349 }
350
351 if no_schema {
352 command.arg("--no-schema");
353 }
354
355 let command_str = format!("{:?}", command);
356
357 let child = command.spawn().wrap_err_with(|| {
358 format!(
359 "Failed to spawn process for installing extension using command: '{}': ",
360 command_str
361 )
362 })?;
363
364 let output = child.wait_with_output().wrap_err_with(|| {
365 format!(
366 "Failed waiting for spawned process attempting to install extension using command: '{}': ",
367 command_str
368 )
369 })?;
370
371 if !output.status.success() {
372 return Err(eyre!(
373 "Failure installing extension using command: {}\n\n{}{}",
374 command_str,
375 String::from_utf8(output.stdout).unwrap(),
376 String::from_utf8(output.stderr).unwrap()
377 ));
378 }
379
380 Ok(())
381}
382
383fn initdb(postgresql_conf: Vec<&'static str>) -> eyre::Result<()> {
384 let pgdata = get_pgdata_path()?;
385
386 if !pgdata.is_dir() {
387 let pg_config = get_pg_config()?;
388 let mut command =
389 Command::new(pg_config.initdb_path().wrap_err("unable to determine initdb path")?);
390
391 command
392 .args(get_c_locale_flags())
393 .arg("-D")
394 .arg(pgdata.to_str().unwrap())
395 .stdout(Stdio::inherit())
396 .stderr(Stdio::inherit());
397
398 let command_str = format!("{:?}", command);
399
400 let child = command.spawn().wrap_err_with(|| {
401 format!(
402 "Failed to spawn process for initializing database using command: '{}': ",
403 command_str
404 )
405 })?;
406
407 let output = child.wait_with_output().wrap_err_with(|| {
408 format!(
409 "Failed waiting for spawned process attempting to initialize database using command: '{}': ",
410 command_str
411 )
412 })?;
413
414 if !output.status.success() {
415 return Err(eyre!(
416 "Failed to initialize database using command: {}\n\n{}{}",
417 command_str,
418 String::from_utf8(output.stdout).unwrap(),
419 String::from_utf8(output.stderr).unwrap()
420 ));
421 }
422 }
423
424 modify_postgresql_conf(pgdata, postgresql_conf)
425}
426
427fn modify_postgresql_conf(pgdata: PathBuf, postgresql_conf: Vec<&'static str>) -> eyre::Result<()> {
428 let mut postgresql_conf_file = std::fs::OpenOptions::new()
429 .write(true)
430 .truncate(true)
431 .open(format!("{}/postgresql.auto.conf", pgdata.display()))
432 .wrap_err("couldn't open postgresql.auto.conf")?;
433 postgresql_conf_file
434 .write_all("log_line_prefix='[%m] [%p] [%c]: '\n".as_bytes())
435 .wrap_err("couldn't append log_line_prefix")?;
436
437 for setting in postgresql_conf {
438 postgresql_conf_file
439 .write_all(format!("{setting}\n").as_bytes())
440 .wrap_err("couldn't append custom setting to postgresql.conf")?;
441 }
442
443 postgresql_conf_file
444 .write_all(
445 format!("unix_socket_directories = '{}'", Pgx::home().unwrap().display()).as_bytes(),
446 )
447 .wrap_err("couldn't append `unix_socket_directories` setting to postgresql.conf")?;
448 Ok(())
449}
450
451fn start_pg(loglines: LogLines) -> eyre::Result<String> {
452 let pg_config = get_pg_config()?;
453 let mut command =
454 Command::new(pg_config.postmaster_path().wrap_err("unable to determine postmaster path")?);
455 command
456 .arg("-D")
457 .arg(get_pgdata_path()?.to_str().unwrap())
458 .arg("-h")
459 .arg(pg_config.host())
460 .arg("-p")
461 .arg(pg_config.test_port().expect("unable to determine test port").to_string())
462 .args(["-c", "log_destination=stderr", "-c", "logging_collector=off"])
464 .stdout(Stdio::inherit())
465 .stderr(Stdio::piped());
466
467 let command_str = format!("{command:?}");
468
469 let session_id = monitor_pg(command, command_str, loglines);
472
473 Ok(session_id)
474}
475
476fn monitor_pg(mut command: Command, cmd_string: String, loglines: LogLines) -> String {
477 let (sender, receiver) = std::sync::mpsc::channel();
478
479 std::thread::spawn(move || {
480 let mut child = command.spawn().expect("postmaster didn't spawn");
481
482 let pid = child.id();
483 add_shutdown_hook(move || unsafe {
487 libc::kill(pid as libc::pid_t, libc::SIGTERM);
488 let message_string = std::ffi::CString::new(
489 format!("stopping postgres (pid={pid})\n").bold().blue().to_string(),
490 )
491 .unwrap();
492 libc::printf("%s\0".as_ptr().cast(), message_string.as_ptr());
494 });
495
496 eprintln!("{cmd}\npid={p}", cmd = cmd_string.bold().blue(), p = pid.to_string().yellow());
497 eprintln!("{}", pg_sys::get_pg_version_string().bold().purple());
498
499 let reader = BufReader::new(child.stderr.take().expect("couldn't take postmaster stderr"));
501
502 let regex = regex::Regex::new(r#"\[.*?\] \[.*?\] \[(?P<session_id>.*?)\]"#).unwrap();
503 let mut is_started_yet = false;
504 let mut lines = reader.lines();
505 while let Some(Ok(line)) = lines.next() {
506 let session_id = match get_named_capture(®ex, "session_id", &line) {
507 Some(sid) => sid,
508 None => "NONE".to_string(),
509 };
510
511 if line.contains("database system is ready to accept connections") {
512 sender.send(session_id.clone()).unwrap();
514 is_started_yet = true;
515 }
516
517 if !is_started_yet || line.contains("TMSG: ") {
518 eprintln!("{}", line.cyan());
519 }
520
521 let mut loglines = loglines.lock().unwrap();
536 let session_lines = loglines.entry(session_id).or_insert_with(Vec::new);
537 session_lines.push(line);
538 }
539
540 match child.try_wait() {
542 Ok(status) => {
543 if let Some(_status) = status {
544 }
546 }
547 Err(e) => panic!("was going to let Postgres finish, but errored this time:\n{e}"),
548 }
549 });
550
551 receiver.recv().expect("Postgres failed to start")
554}
555
556fn dropdb() -> eyre::Result<()> {
557 let pg_config = get_pg_config()?;
558 let output = Command::new(pg_config.dropdb_path().expect("unable to determine dropdb path"))
559 .env_remove("PGDATABASE")
560 .env_remove("PGHOST")
561 .env_remove("PGPORT")
562 .env_remove("PGUSER")
563 .arg("--if-exists")
564 .arg("-h")
565 .arg(pg_config.host())
566 .arg("-p")
567 .arg(pg_config.test_port().expect("unable to determine test port").to_string())
568 .arg(get_pg_dbname())
569 .output()
570 .unwrap();
571
572 if !output.status.success() {
573 let stderr = String::from_utf8_lossy(output.stderr.as_slice());
575 if !stderr.contains(&format!("ERROR: database \"{}\" does not exist", get_pg_dbname())) {
576 let stdout = String::from_utf8_lossy(output.stdout.as_slice());
578 eprintln!("unexpected error (stdout):\n{stdout}");
579 eprintln!("unexpected error (stderr):\n{stderr}");
580 panic!("failed to drop test database");
581 }
582 }
583
584 Ok(())
585}
586
587fn create_extension() -> eyre::Result<()> {
588 let (mut client, _) = client()?;
589 let extension_name = get_extension_name();
590
591 query_wrapper(
592 Some(format!("CREATE EXTENSION {} CASCADE;", &extension_name)),
593 None,
594 |query, _| client.simple_query(query.unwrap().as_str()),
595 )
596 .wrap_err(format!(
597 "There was an issue creating the extension '{}' in Postgres: ",
598 &extension_name
599 ))?;
600
601 Ok(())
602}
603
/// Returns the SQL extension name: the runtime `CARGO_PKG_NAME` with every `-` replaced by `_`.
///
/// # Panics
/// Panics if `CARGO_PKG_NAME` is unset or not valid UTF-8.
fn get_extension_name() -> String {
    match std::env::var("CARGO_PKG_NAME") {
        Ok(pkg_name) => pkg_name.replace('-', "_"),
        Err(_) => panic!("CARGO_PKG_NAME environment var is unset or invalid UTF-8"),
    }
}
609
610fn get_pgdata_path() -> eyre::Result<PathBuf> {
611 let mut target_dir = get_target_dir()?;
612 target_dir.push(&format!("pgx-test-data-{}", pg_sys::get_pg_major_version_num()));
613 Ok(target_dir)
614}
615
/// Name of the database the tests run in; it is dropped and recreated during framework setup.
pub(crate) fn get_pg_dbname() -> &'static str {
    const TEST_DATABASE_NAME: &str = "pgx_tests";
    TEST_DATABASE_NAME
}
619
/// Returns the value of `$USER`, which is used as the role when connecting to the test
/// cluster.
///
/// # Panics
/// Panics if `USER` is unset or not valid UTF-8.
pub(crate) fn get_pg_user() -> String {
    match std::env::var("USER") {
        Ok(user) => user,
        Err(_) => panic!("USER environment var is unset or invalid UTF-8"),
    }
}
624
625pub fn get_named_capture(
626 regex: ®ex::Regex,
627 name: &'static str,
628 against: &str,
629) -> Option<String> {
630 match regex.captures(against) {
631 Some(cap) => Some(cap[name].to_string()),
632 None => None,
633 }
634}
635
636fn get_cargo_test_features() -> eyre::Result<clap_cargo::Features> {
637 let mut features = clap_cargo::Features::default();
638 let cargo_user_args = get_cargo_args();
639 let mut iter = cargo_user_args.iter();
640 while let Some(part) = iter.next() {
641 match part.as_str() {
642 "--no-default-features" => features.no_default_features = true,
643 "--features" => {
644 let configured_features = iter.next().ok_or(eyre!(
645 "no `--features` specified in the cargo argument list: {:?}",
646 cargo_user_args
647 ))?;
648 features.features = configured_features
649 .split(|c: char| c.is_ascii_whitespace() || c == ',')
650 .map(|s| s.to_string())
651 .collect();
652 }
653 "--all-features" => features.all_features = true,
654 _ => {}
655 }
656 }
657
658 Ok(features)
659}
660
661fn get_cargo_args() -> Vec<String> {
662 let mut system = System::new_all();
664 system.refresh_all();
665
666 let mut pid = Pid::from(std::process::id() as usize);
678 while let Some(process) = system.process(pid) {
679 if process.exe().ends_with("cargo") {
682 if process.cmd().iter().any(|arg| arg == "test")
684 && !process.cmd().iter().any(|arg| arg == "pgx")
685 {
686 return process.cmd().iter().cloned().collect();
688 }
689 }
690
691 match process.parent() {
693 Some(parent_pid) => pid = parent_pid,
694 None => break,
695 }
696 }
697
698 Vec::new()
699}
700
701fn cargo_pgx() -> std::process::Command {
704 fn var_path(s: &str) -> Option<PathBuf> {
705 std::env::var_os(s).map(PathBuf::from)
706 }
707 let cargo_pgx = var_path("CARGO_PGX")
710 .or_else(|| find_on_path("cargo-pgx"))
711 .or_else(|| var_path("CARGO"))
712 .unwrap_or_else(|| "cargo".into());
713 let mut cmd = std::process::Command::new(cargo_pgx);
714 cmd.arg("pgx");
715 cmd
716}
717
/// Searches each directory listed in `PATH` for a file named `program` and returns the first
/// match.
///
/// Entries with that name that are not files (for example directories) are skipped, since
/// they cannot be executed. No executable-bit check or Windows extension lookup is done.
/// Returns `None` if `PATH` is unset or nothing matches.
///
/// # Panics
/// Panics if `program` contains `/`; it must be a bare file name.
fn find_on_path(program: &str) -> Option<PathBuf> {
    assert!(!program.contains('/'));
    let paths = std::env::var_os("PATH")?;
    std::env::split_paths(&paths)
        .map(|dir| dir.join(program))
        .find(|candidate| candidate.is_file())
}