gel_pg_captive/
lib.rs

1use ephemeral_port::EphemeralPort;
2use gel_auth::AuthType;
3use gel_stream::ResolvedTarget;
4use std::io::{BufReader, Write};
5use std::net::{Ipv4Addr, SocketAddr};
6use std::num::NonZeroUsize;
7use std::path::{Path, PathBuf};
8use std::process::{Command, Stdio};
9use std::time::{Duration, Instant};
10use stdio_reader::StdioReader;
11use tempfile::TempDir;
12
13mod ephemeral_port;
14mod stdio_reader;
15
16// Constants
/// Maximum time `run_postgres` waits for a freshly spawned server to report
/// readiness and accept connections.
17pub const STARTUP_TIMEOUT_DURATION: Duration = Duration::from_secs(30);
// NOTE(review): not referenced in this file; presumably consumed by the
// `ephemeral_port` module — confirm.
18pub const PORT_RELEASE_TIMEOUT: Duration = Duration::from_secs(30);
// NOTE(review): not referenced in this file — confirm where it is used.
19pub const LINGER_DURATION: Duration = Duration::from_secs(1);
/// Sleep between iterations of the polling loops in this file (process exit
/// checks and server readiness checks).
20pub const HOT_LOOP_INTERVAL: Duration = Duration::from_millis(100);
/// User name passed to `initdb -U` and `pg_basebackup -U`.
21pub const DEFAULT_USERNAME: &str = "username";
/// Password written to the `initdb --pwfile` file and exported as
/// `PGPASSWORD` for `pg_basebackup`.
22pub const DEFAULT_PASSWORD: &str = "password";
// NOTE(review): not referenced in this file; presumably the database name
// callers should connect to — confirm.
23pub const DEFAULT_DATABASE: &str = "postgres";
24
25use std::collections::HashMap;
26
/// Where `PostgresBuilder::build` looks for the PostgreSQL executables.
27#[derive(Debug, Clone, Default)]
28pub enum PostgresBinPath {
    /// Use the bare executable names (`initdb`, `postgres`, `pg_basebackup`)
    /// with no directory prefix. This is the default.
29    #[default]
30    Path,
    /// Look for the executables inside the given directory.
31    Specified(PathBuf),
32}
33
/// Configuration for launching a captive PostgreSQL server; consumed by
/// `PostgresBuilder::build`.
34#[derive(Debug, Clone)]
35pub struct PostgresBuilder {
    /// Authentication method handed to `initdb -A` when the data directory is
    /// initialized.
36    auth: AuthType,
    /// Where the PostgreSQL executables are looked up.
37    bin_path: PostgresBinPath,
    /// Data directory; when `None`, `build` uses `data` inside a fresh
    /// temporary directory.
38    data_dir: Option<PathBuf>,
    /// Extra settings, each passed to the server as `-c key=value`.
39    server_options: HashMap<String, String>,
    /// Certificate and key contents written to `server.crt` / `server.key` in
    /// the data directory; when set the server is started with `-l`.
40    ssl_cert_and_key: Option<(String, String)>,
    /// When true, the server is started with `-k <data_dir>` and the reported
    /// socket address is the Unix socket path.
41    unix_enabled: bool,
    /// Optional value for the server's `-d` flag.
42    debug_level: Option<u8>,
    /// When set, the data directory is seeded with `pg_basebackup` from
    /// `localhost:<port>` and a `standby.signal` file instead of `initdb`.
43    standby_of_port: Option<u16>,
44}
45
46impl Default for PostgresBuilder {
47    fn default() -> Self {
48        Self {
49            auth: AuthType::Trust,
50            bin_path: PostgresBinPath::default(),
51            data_dir: None,
52            server_options: HashMap::new(),
53            ssl_cert_and_key: None,
54            unix_enabled: false,
55            debug_level: None,
56            standby_of_port: None,
57        }
58    }
59}
60
61impl PostgresBuilder {
62    pub fn new() -> Self {
63        Self::default()
64    }
65
66    /// Attempt to configure the builder to use the default postgres binaries.
67    /// Returns an error if the binaries are not found.
68    pub fn with_automatic_bin_path(mut self) -> std::io::Result<Self> {
69        let bindir = postgres_bin_dir()?;
70        self.bin_path = PostgresBinPath::Specified(bindir);
71        Ok(self)
72    }
73
74    /// Configures the builder with a quick networking mode.
75    pub fn with_automatic_mode(mut self, mode: Mode) -> Self {
76        match mode {
77            Mode::Tcp => {
78                // No special configuration needed for TCP mode
79            }
80            Mode::TcpSsl => {
81                use gel_stream::test_keys::raw::*;
82                self.ssl_cert_and_key = Some((SERVER_CERT.to_string(), SERVER_KEY.to_string()));
83            }
84            Mode::Unix => {
85                self.unix_enabled = true;
86            }
87        }
88        self
89    }
90
91    pub fn auth(mut self, auth: AuthType) -> Self {
92        self.auth = auth;
93        self
94    }
95
96    pub fn bin_path(mut self, bin_path: impl AsRef<Path>) -> Self {
97        self.bin_path = PostgresBinPath::Specified(bin_path.as_ref().to_path_buf());
98        self
99    }
100
101    pub fn data_dir(mut self, data_dir: PathBuf) -> Self {
102        self.data_dir = Some(data_dir);
103        self
104    }
105
106    pub fn debug_level(mut self, debug_level: u8) -> Self {
107        self.debug_level = Some(debug_level);
108        self
109    }
110
111    pub fn server_option(mut self, key: impl AsRef<str>, value: impl AsRef<str>) -> Self {
112        self.server_options
113            .insert(key.as_ref().to_string(), value.as_ref().to_string());
114        self
115    }
116
117    pub fn server_options(
118        mut self,
119        server_options: impl IntoIterator<Item = (impl AsRef<str>, impl AsRef<str>)>,
120    ) -> Self {
121        for (key, value) in server_options {
122            self.server_options
123                .insert(key.as_ref().to_string(), value.as_ref().to_string());
124        }
125        self
126    }
127
128    pub fn enable_ssl(mut self, cert: String, key: String) -> Self {
129        self.ssl_cert_and_key = Some((cert, key));
130        self
131    }
132
133    pub fn enable_unix(mut self) -> Self {
134        self.unix_enabled = true;
135        self
136    }
137
138    pub fn enable_standby_of(mut self, port: u16) -> Self {
139        self.standby_of_port = Some(port);
140        self
141    }
142
143    pub fn build(self) -> std::io::Result<PostgresProcess> {
144        let initdb = match &self.bin_path {
145            PostgresBinPath::Path => "initdb".into(),
146            PostgresBinPath::Specified(path) => path.join("initdb"),
147        };
148        let postgres = match &self.bin_path {
149            PostgresBinPath::Path => "postgres".into(),
150            PostgresBinPath::Specified(path) => path.join("postgres"),
151        };
152        let pg_basebackup = match &self.bin_path {
153            PostgresBinPath::Path => "pg_basebackup".into(),
154            PostgresBinPath::Specified(path) => path.join("pg_basebackup"),
155        };
156
157        if !initdb.exists() {
158            return Err(std::io::Error::new(
159                std::io::ErrorKind::NotFound,
160                format!("initdb executable not found at {}", initdb.display()),
161            ));
162        }
163        if !postgres.exists() {
164            return Err(std::io::Error::new(
165                std::io::ErrorKind::NotFound,
166                format!("postgres executable not found at {}", postgres.display()),
167            ));
168        }
169        if !pg_basebackup.exists() {
170            return Err(std::io::Error::new(
171                std::io::ErrorKind::NotFound,
172                format!(
173                    "pg_basebackup executable not found at {}",
174                    pg_basebackup.display()
175                ),
176            ));
177        }
178
179        let temp_dir = TempDir::new()?;
180        let port = EphemeralPort::allocate()?;
181        let data_dir = self
182            .data_dir
183            .unwrap_or_else(|| temp_dir.path().join("data"));
184
185        // Create a standby signal file if requested
186        if let Some(standby_of_port) = self.standby_of_port {
187            run_pgbasebackup(&pg_basebackup, &data_dir, "localhost", standby_of_port)?;
188            let standby_signal_path = data_dir.join("standby.signal");
189            std::fs::write(&standby_signal_path, "")?;
190        } else {
191            init_postgres(&initdb, &data_dir, self.auth)?;
192        }
193
194        let port = port.take();
195
196        let ssl_config = self.ssl_cert_and_key;
197
198        let (socket_address, socket_path) = if self.unix_enabled {
199            #[cfg(windows)]
200            unreachable!("Unix mode is not supported on Windows");
201            #[cfg(unix)]
202            (
203                ResolvedTarget::try_from(get_unix_socket_path(&data_dir, port))?,
204                Some(&data_dir),
205            )
206        } else {
207            (
208                ResolvedTarget::SocketAddr(SocketAddr::new(Ipv4Addr::LOCALHOST.into(), port)),
209                None::<&PathBuf>,
210            )
211        };
212
213        let tcp_address = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), port);
214
215        let mut command = Command::new(postgres);
216        command
217            .stdout(Stdio::piped())
218            .stderr(Stdio::piped())
219            .arg("-D")
220            .arg(&data_dir)
221            .arg("-h")
222            .arg(Ipv4Addr::LOCALHOST.to_string())
223            .arg("-F")
224            .arg("-p")
225            .arg(port.to_string());
226
227        if let Some(socket_path) = &socket_path {
228            command.arg("-k").arg(socket_path);
229        }
230
231        for (key, value) in self.server_options {
232            command.arg("-c").arg(format!("{}={}", key, value));
233        }
234
235        if let Some(debug_level) = self.debug_level {
236            command.arg("-d").arg(debug_level.to_string());
237        }
238
239        let child = run_postgres(command, &data_dir, socket_path, ssl_config, port)?;
240
241        Ok(PostgresProcess {
242            child: Some(child),
243            socket_address,
244            tcp_address,
245            temp_dir,
246        })
247    }
248}
249
250fn spawn(command: &mut Command) -> std::io::Result<()> {
251    command.stdout(Stdio::piped());
252    command.stderr(Stdio::piped());
253
254    let program = Path::new(command.get_program())
255        .file_name()
256        .unwrap_or_default()
257        .to_string_lossy()
258        .to_string();
259
260    eprintln!("{program} command:\n  {:?}", command);
261    let command = command.spawn()?;
262    let output = std::thread::scope(|s| {
263        #[cfg(unix)]
264        use nix::{
265            sys::signal::{self, Signal},
266            unistd::Pid,
267        };
268
269        #[cfg(unix)]
270        let pid = Pid::from_raw(command.id() as _);
271
272        let handle = s.spawn(|| command.wait_with_output());
273        let start = Instant::now();
274        while start.elapsed() < Duration::from_secs(30) {
275            if handle.is_finished() {
276                let handle = handle
277                    .join()
278                    .map_err(|e| std::io::Error::other(format!("{e:?}")))??;
279                return Ok(handle);
280            }
281            std::thread::sleep(HOT_LOOP_INTERVAL);
282        }
283
284        #[cfg(unix)]
285        {
286            eprintln!("Command timed out after 30 seconds. Sending SIGKILL.");
287            signal::kill(pid, Signal::SIGKILL)?;
288        }
289        handle
290            .join()
291            .map_err(|e| std::io::Error::other(format!("{e:?}")))?
292    })?;
293    eprintln!("{program}: {}", output.status);
294    let status = output.status;
295    let output_str = String::from_utf8_lossy(&output.stdout).trim().to_string();
296    let error_str = String::from_utf8_lossy(&output.stderr).trim().to_string();
297
298    if !output_str.is_empty() {
299        eprintln!("=== begin {} stdout:===", program);
300        eprintln!("{}", output_str);
301        if !output_str.ends_with('\n') {
302            eprintln!();
303        }
304        eprintln!("=== end {} stdout ===", program);
305    }
306    if !error_str.is_empty() {
307        eprintln!("=== begin {} stderr:===", program);
308        eprintln!("{}", error_str);
309        if !error_str.ends_with('\n') {
310            eprintln!();
311        }
312        eprintln!("=== end {} stderr ===", program);
313    }
314    if output_str.is_empty() && error_str.is_empty() {
315        eprintln!("{program}: No output\n");
316    }
317    if !status.success() {
318        return Err(std::io::Error::other(format!(
319            "{program} failed with: {}",
320            status
321        )));
322    }
323
324    Ok(())
325}
326
327fn init_postgres(initdb: &Path, data_dir: &Path, auth: AuthType) -> std::io::Result<()> {
328    let mut pwfile = tempfile::NamedTempFile::new()?;
329    writeln!(pwfile, "{}", DEFAULT_PASSWORD)?;
330    let mut command = Command::new(initdb);
331    command
332        .arg("-D")
333        .arg(data_dir)
334        .arg("-A")
335        .arg(match auth {
336            AuthType::Deny => "reject",
337            AuthType::Trust => "trust",
338            AuthType::Plain => "password",
339            AuthType::Md5 => "md5",
340            AuthType::ScramSha256 => "scram-sha-256",
341        })
342        .arg("--pwfile")
343        .arg(pwfile.path())
344        .arg("-U")
345        .arg(DEFAULT_USERNAME)
346        .arg("--no-instructions");
347
348    spawn(&mut command)?;
349
350    Ok(())
351}
352
353fn run_pgbasebackup(
354    pg_basebackup: &Path,
355    data_dir: &Path,
356    host: &str,
357    port: u16,
358) -> std::io::Result<()> {
359    let mut command = Command::new(pg_basebackup);
360    // This works for testing purposes but putting passwords in the environment
361    // is usually bad practice.
362    //
363    // "Use of this environment variable is not recommended for security
364    // reasons" <https://www.postgresql.org/docs/current/libpq-envars.html>
365    command.env("PGPASSWORD", DEFAULT_PASSWORD);
366    command
367        .arg("-D")
368        .arg(data_dir)
369        .arg("-h")
370        .arg(host)
371        .arg("-p")
372        .arg(port.to_string())
373        .arg("-U")
374        .arg(DEFAULT_USERNAME)
375        .arg("-X")
376        .arg("stream")
377        .arg("-w");
378
379    spawn(&mut command)?;
380    Ok(())
381}
382
383fn run_postgres(
384    mut command: Command,
385    data_dir: &Path,
386    socket_path: Option<impl AsRef<Path>>,
387    ssl: Option<(String, String)>,
388    port: u16,
389) -> std::io::Result<std::process::Child> {
390    let socket_path = socket_path.map(|path| path.as_ref().to_owned());
391
392    if let Some((cert_pem, key_pem)) = ssl {
393        let postgres_cert_path = data_dir.join("server.crt");
394        let postgres_key_path = data_dir.join("server.key");
395        std::fs::write(&postgres_cert_path, cert_pem)?;
396        std::fs::write(&postgres_key_path, key_pem)?;
397
398        #[cfg(unix)]
399        {
400            use std::os::unix::fs::PermissionsExt;
401            // Set permissions for the certificate and key files
402            std::fs::set_permissions(&postgres_cert_path, std::fs::Permissions::from_mode(0o600))?;
403            std::fs::set_permissions(&postgres_key_path, std::fs::Permissions::from_mode(0o600))?;
404        }
405
406        // Edit pg_hba.conf to change all "host" line prefixes to "hostssl"
407        let pg_hba_path = data_dir.join("pg_hba.conf");
408        let content = std::fs::read_to_string(&pg_hba_path)?;
409        let modified_content = content
410            .lines()
411            .filter(|line| !line.starts_with("#") && !line.is_empty())
412            .map(|line| {
413                if line.trim_start().starts_with("host") {
414                    line.replacen("host", "hostssl", 1)
415                } else {
416                    line.to_string()
417                }
418            })
419            .collect::<Vec<String>>()
420            .join("\n");
421        eprintln!("pg_hba.conf:\n==========\n{modified_content}\n==========");
422        std::fs::write(&pg_hba_path, modified_content)?;
423
424        command.arg("-l");
425    }
426
427    eprintln!("postgres command:\n  {:?}", command);
428    let mut child = command.spawn()?;
429
430    let stdout_reader = BufReader::new(child.stdout.take().expect("Failed to capture stdout"));
431    let _ = StdioReader::spawn(stdout_reader, format!("pg_stdout {}", child.id()));
432    let stderr_reader = BufReader::new(child.stderr.take().expect("Failed to capture stderr"));
433    let stderr_reader = StdioReader::spawn(stderr_reader, format!("pg_stderr {}", child.id()));
434
435    let start_time = Instant::now();
436
437    let mut tcp_socket: Option<std::net::TcpStream> = None;
438    #[cfg(unix)]
439    let mut unix_socket: Option<std::os::unix::net::UnixStream> = None;
440    #[cfg(unix)]
441    let unix_socket_path = socket_path.map(|path| get_unix_socket_path(path, port));
442    let tcp_socket_addr = std::net::SocketAddr::from((Ipv4Addr::LOCALHOST, port));
443
444    let mut db_ready = false;
445    let mut network_ready = false;
446
447    while start_time.elapsed() < STARTUP_TIMEOUT_DURATION && !network_ready {
448        std::thread::sleep(HOT_LOOP_INTERVAL);
449        match child.try_wait() {
450            Ok(Some(status)) => {
451                return Err(std::io::Error::other(format!(
452                    "PostgreSQL exited with status: {}",
453                    status
454                )))
455            }
456            Err(e) => return Err(e),
457            _ => {}
458        }
459        if !db_ready && stderr_reader.contains("database system is ready to accept ") {
460            eprintln!("Database is ready");
461            db_ready = true;
462        } else {
463            continue;
464        }
465        #[cfg(unix)]
466        if let Some(unix_socket_path) = &unix_socket_path {
467            if unix_socket.is_none() {
468                unix_socket = std::os::unix::net::UnixStream::connect(unix_socket_path).ok();
469            }
470        }
471        if tcp_socket.is_none() {
472            tcp_socket = std::net::TcpStream::connect(tcp_socket_addr).ok();
473        }
474
475        #[cfg(unix)]
476        {
477            network_ready =
478                (unix_socket_path.is_none() || unix_socket.is_some()) && tcp_socket.is_some();
479        }
480        #[cfg(not(unix))]
481        {
482            network_ready = tcp_socket.is_some();
483        }
484    }
485
486    // Print status for TCP/unix sockets
487    if let Some(tcp) = &tcp_socket {
488        eprintln!(
489            "TCP socket at {tcp_socket_addr:?} bound successfully (local address was {})",
490            tcp.local_addr()?
491        );
492    } else {
493        eprintln!("TCP socket at {tcp_socket_addr:?} binding failed");
494    }
495
496    #[cfg(unix)]
497    if let Some(unix_socket_path) = &unix_socket_path {
498        if unix_socket.is_some() {
499            eprintln!("Unix socket at {unix_socket_path:?} connected successfully");
500        } else {
501            eprintln!("Unix socket at {unix_socket_path:?} connection failed");
502        }
503    }
504
505    if network_ready {
506        return Ok(child);
507    }
508
509    Err(std::io::Error::new(
510        std::io::ErrorKind::TimedOut,
511        "PostgreSQL failed to start within 30 seconds",
512    ))
513}
514
/// Locates a `bin` directory under `~/.local/share/edgedb/portable`.
///
/// Every entry of the portable directory that contains a `bin` child is a
/// candidate; the candidate whose path orders last is returned.
///
/// # Errors
///
/// Returns `NotFound` if the home directory is unknown or no candidate
/// exists, and propagates the error if the portable directory cannot be read.
fn postgres_bin_dir() -> std::io::Result<std::path::PathBuf> {
    let Some(home) = std::env::home_dir() else {
        return Err(std::io::Error::new(
            std::io::ErrorKind::NotFound,
            "Home directory not found",
        ));
    };
    let portable_bin_path = home.join(".local/share/edgedb/portable");
    eprintln!("Portable path: {portable_bin_path:?}");

    // Unreadable directory entries are skipped rather than treated as fatal.
    std::fs::read_dir(portable_bin_path)?
        .flatten()
        .map(|entry| entry.path().join("bin"))
        .filter(|candidate| candidate.exists())
        .inspect(|candidate| eprintln!("Found postgres bin path: {candidate:?}"))
        .max()
        .ok_or_else(|| {
            std::io::Error::new(std::io::ErrorKind::NotFound, "No postgres versions found")
        })
}
540
/// Builds the path of the socket file `.s.PGSQL.<port>` inside the
/// `socket_path` directory.
fn get_unix_socket_path(socket_path: impl AsRef<Path>, port: u16) -> PathBuf {
    let socket_file = format!(".s.PGSQL.{port}");
    let mut full_path = socket_path.as_ref().to_path_buf();
    full_path.push(socket_file);
    full_path
}
544
/// Networking presets accepted by `PostgresBuilder::with_automatic_mode`.
545#[derive(Debug, Clone, Copy)]
546pub enum Mode {
    /// Plain TCP; the builder is left unchanged.
547    Tcp,
    /// TCP with SSL, using the test certificate and key from
    /// `gel_stream::test_keys`.
548    TcpSsl,
    /// Additionally enables the Unix-domain socket.
549    Unix,
550}
551
552/// The signal to send to the server to shut it down.
553///
554/// <https://www.postgresql.org/docs/8.1/postmaster-shutdown.html>
///
/// `PostgresProcess::notify_shutdown` maps each variant to the POSIX signal
/// noted on it.
555#[derive(Debug, Clone, Copy)]
556pub enum ShutdownSignal {
557    /// "After receiving SIGTERM, the server disallows new connections, but lets
558    /// existing sessions end their work normally. It shuts down only after all
559    /// of the sessions terminate normally. This is the Smart Shutdown."
    ///
    /// Sent as `SIGTERM`.
560    Smart,
561    /// "The server disallows new connections and sends all existing server
562    /// processes SIGTERM, which will cause them to abort their current
563    /// transactions and exit promptly. It then waits for the server processes
564    /// to exit and finally shuts down. This is the Fast Shutdown."
    ///
    /// Sent as `SIGINT`.
565    Fast,
566    /// "This is the Immediate Shutdown, which will cause the postmaster process
567    /// to send a SIGQUIT to all child processes and exit immediately, without
568    /// properly shutting itself down. The child processes likewise exit
569    /// immediately upon receiving SIGQUIT. This will lead to recovery (by
570    /// replaying the WAL log) upon next start-up. This is recommended only in
571    /// emergencies."
    ///
    /// Sent as `SIGQUIT`.
572    Immediate,
573    /// "It is best not to use SIGKILL to shut down the server. Doing so will
574    /// prevent the server from releasing shared memory and semaphores, which
575    /// may then have to be done manually before a new server can be started.
576    /// Furthermore, SIGKILL kills the postmaster process without letting it
577    /// relay the signal to its subprocesses, so it will be necessary to kill
578    /// the individual subprocesses by hand as well."
    ///
    /// Sent as `SIGKILL`.
579    Forceful,
580}
581
/// A primary server together with the standby servers replicating from it,
/// as assembled by `create_cluster`.
582#[derive(Debug)]
583pub struct PostgresCluster {
    /// The server the standbys were seeded from.
584    primary: PostgresProcess,
    /// Servers started as standbys of the primary's TCP port.
585    standbys: Vec<PostgresProcess>,
586}
587
588impl PostgresCluster {
589    #[cfg(unix)]
590    pub fn shutdown_timeout(
591        self,
592        timeout: Duration,
593        signal: ShutdownSignal,
594    ) -> Result<(), Vec<PostgresProcess>> {
595        let mut failed = Vec::new();
596        for standby in self.standbys {
597            if let Err(e) = standby.shutdown_timeout(timeout, signal) {
598                failed.push(e);
599            }
600        }
601        if let Err(e) = self.primary.shutdown_timeout(timeout, signal) {
602            failed.push(e);
603        }
604        if failed.is_empty() {
605            Ok(())
606        } else {
607            Err(failed)
608        }
609    }
610}
611
/// A running captive PostgreSQL server, as returned by
/// `PostgresBuilder::build`.
612#[derive(Debug)]
613pub struct PostgresProcess {
    /// The server process; `None` once `shutdown_timeout` has observed its
    /// exit or `Drop` has taken it over.
614    child: Option<std::process::Child>,
    /// Preferred address: the Unix socket when Unix mode is enabled,
    /// otherwise the same TCP address as `tcp_address`.
615    pub socket_address: ResolvedTarget,
    /// The `127.0.0.1:<port>` address the server listens on.
616    pub tcp_address: SocketAddr,
    // Never read; held so the temporary directory created by `build` lives as
    // long as this value. NOTE(review): presumably `TempDir` deletes the
    // directory when dropped — confirm against the `tempfile` docs.
617    #[allow(unused)]
618    temp_dir: TempDir,
619}
620
621impl PostgresProcess {
622    fn child(&self) -> &std::process::Child {
623        self.child.as_ref().unwrap()
624    }
625
626    fn child_mut(&mut self) -> &mut std::process::Child {
627        self.child.as_mut().unwrap()
628    }
629
630    #[cfg(unix)]
631    pub fn notify_shutdown(&mut self, signal: ShutdownSignal) -> std::io::Result<()> {
632        use nix::sys::signal::{self, Signal};
633        use nix::unistd::Pid;
634
635        let id = Pid::from_raw(self.child().id() as _);
636        // https://www.postgresql.org/docs/8.1/postmaster-shutdown.html
637        match signal {
638            ShutdownSignal::Smart => signal::kill(id, Signal::SIGTERM)?,
639            ShutdownSignal::Fast => signal::kill(id, Signal::SIGINT)?,
640            ShutdownSignal::Immediate => signal::kill(id, Signal::SIGQUIT)?,
641            ShutdownSignal::Forceful => signal::kill(id, Signal::SIGKILL)?,
642        }
643        Ok(())
644    }
645
646    pub fn try_wait(&mut self) -> std::io::Result<Option<std::process::ExitStatus>> {
647        self.child_mut().try_wait()
648    }
649
650    /// Try to shut down, waiting up to `timeout` for the process to exit.
651    #[cfg(unix)]
652    pub fn shutdown_timeout(
653        mut self,
654        timeout: Duration,
655        signal: ShutdownSignal,
656    ) -> Result<std::process::ExitStatus, Self> {
657        _ = self.notify_shutdown(signal);
658
659        let id = self.child().id();
660
661        let start = Instant::now();
662        while start.elapsed() < timeout {
663            if let Ok(Some(exit)) = self.child_mut().try_wait() {
664                self.child = None;
665                eprintln!("Process {id} died gracefully. ({exit:?})");
666                return Ok(exit);
667            }
668            std::thread::sleep(HOT_LOOP_INTERVAL);
669        }
670        Err(self)
671    }
672}
673
674#[cfg(unix)]
675impl Drop for PostgresProcess {
676    fn drop(&mut self) {
677        use nix::sys::signal::{self, Signal};
678        use nix::unistd::Pid;
679
680        let Some(mut child) = self.child.take() else {
681            return;
682        };
683
684        // Create a thread to send SIGQUIT to the child process. The thread will not block
685        // process exit.
686
687        let id = Pid::from_raw(child.id() as _);
688        eprintln!("Shutting down Postgres process with pid {id}");
689        if let Ok(Some(_)) = child.try_wait() {
690            eprintln!("Process {id} already exited (crashed?).");
691            return;
692        }
693        if let Err(e) = signal::kill(id, Signal::SIGQUIT) {
694            eprintln!("Failed to send SIGQUIT to process {id}: {e:?}");
695        }
696
697        let builder = std::thread::Builder::new().name("postgres-shutdown-signal".into());
698        builder
699            .spawn(move || {
700                // Instead of sleeping, loop and check if the child process has exited every 100ms for up to 10 seconds.
701                let start = Instant::now();
702                while start.elapsed() < std::time::Duration::from_secs(10) {
703                    if let Ok(Some(_)) = child.try_wait() {
704                        eprintln!("Process {id} died gracefully.");
705                        return;
706                    }
707                    std::thread::sleep(HOT_LOOP_INTERVAL);
708                }
709                eprintln!("Process {id} did not die gracefully. Sending SIGKILL.");
710                _ = signal::kill(id, Signal::SIGKILL);
711            })
712            .unwrap();
713    }
714}
715
716/// Creates and runs a new Postgres server process in a temporary directory.
717pub fn setup_postgres(auth: AuthType, mode: Mode) -> std::io::Result<Option<PostgresProcess>> {
718    let builder: PostgresBuilder = PostgresBuilder::new();
719
720    let Ok(mut builder) = builder.with_automatic_bin_path() else {
721        eprintln!("Skipping test: postgres bin dir not found");
722        return Ok(None);
723    };
724
725    builder = builder.auth(auth).with_automatic_mode(mode);
726
727    let process = builder.build()?;
728    Ok(Some(process))
729}
730
731pub fn create_cluster(
732    auth: AuthType,
733    size: NonZeroUsize,
734) -> std::io::Result<Option<PostgresCluster>> {
735    let builder: PostgresBuilder = PostgresBuilder::new();
736
737    let Ok(mut builder) = builder.with_automatic_bin_path() else {
738        eprintln!("Skipping test: postgres bin dir not found");
739        return Ok(None);
740    };
741
742    builder = builder.auth(auth).with_automatic_mode(Mode::Tcp);
743
744    // Primary requires the following postgres settings:
745    // - wal_level = replica
746
747    let primary = builder
748        .clone()
749        .server_option("wal_level", "replica")
750        .build()?;
751    let primary_port = primary.tcp_address.port();
752
753    let mut cluster = PostgresCluster {
754        primary,
755        standbys: vec![],
756    };
757
758    // Standby requires the following postgres settings:
759    // - primary_conninfo = 'host=localhost port=<port> user=postgres password=password'
760    // - hot_standby = on
761
762    for _ in 0..size.get() - 1 {
763        let builder = builder.clone()
764            .server_option("primary_conninfo", format!("host=localhost port={primary_port} user={DEFAULT_USERNAME} password={DEFAULT_PASSWORD}"))
765            .server_option("hot_standby", "on")
766            .enable_standby_of(primary_port);
767        let standby = builder.build()?;
768        cluster.standbys.push(standby);
769    }
770
771    Ok(Some(cluster))
772}
773
#[cfg(test)]
mod tests {
    use super::*;
    use std::{num::NonZeroUsize, path::PathBuf};

    /// A fresh builder uses trust auth, bare executable names, no explicit
    /// data directory and no extra server options.
    #[test]
    fn test_builder_defaults() {
        let PostgresBuilder {
            auth,
            bin_path,
            data_dir,
            server_options,
            ..
        } = PostgresBuilder::new();

        assert!(matches!(auth, AuthType::Trust));
        assert!(matches!(bin_path, PostgresBinPath::Path));
        assert!(data_dir.is_none());
        assert_eq!(server_options.len(), 0);
    }

    /// Every setter stores the value it was given.
    #[test]
    fn test_builder_customization() {
        let options = HashMap::from([("max_connections", "100")]);
        let data_dir = PathBuf::from("/tmp/pg_data");
        let bin_path = PathBuf::from("/usr/local/pgsql/bin");

        let configured = PostgresBuilder::new()
            .auth(AuthType::Md5)
            .bin_path(bin_path)
            .data_dir(data_dir.clone())
            .server_options(options);

        assert!(matches!(configured.auth, AuthType::Md5));
        assert!(matches!(configured.bin_path, PostgresBinPath::Specified(_)));
        assert_eq!(
            configured.server_options.get("max_connections").unwrap(),
            "100"
        );
        assert_eq!(configured.data_dir.unwrap(), data_dir);
    }

    /// A two-node cluster has exactly one standby and shuts down cleanly.
    /// Silently passes when no postgres installation is available.
    #[test]
    #[cfg(unix)]
    fn test_create_cluster() {
        let two_nodes = NonZeroUsize::new(2).unwrap();
        let cluster = match create_cluster(AuthType::Md5, two_nodes).unwrap() {
            Some(cluster) => cluster,
            None => return,
        };
        assert_eq!(cluster.standbys.len(), 1);
        cluster
            .shutdown_timeout(Duration::from_secs(10), ShutdownSignal::Smart)
            .unwrap();
    }
}