postgresql_embedded/
postgresql.rs

1use crate::error::Error::{DatabaseInitializationError, DatabaseStartError, DatabaseStopError};
2use crate::error::Result;
3use crate::settings::{BOOTSTRAP_DATABASE, BOOTSTRAP_SUPERUSER, Settings};
4use postgresql_archive::get_version;
5use postgresql_archive::{ExactVersion, ExactVersionReq};
6use postgresql_archive::{extract, get_archive};
7#[cfg(feature = "tokio")]
8use postgresql_commands::AsyncCommandExecutor;
9use postgresql_commands::CommandBuilder;
10#[cfg(not(feature = "tokio"))]
11use postgresql_commands::CommandExecutor;
12use postgresql_commands::initdb::InitDbBuilder;
13use postgresql_commands::pg_ctl::Mode::{Start, Stop};
14use postgresql_commands::pg_ctl::PgCtlBuilder;
15use postgresql_commands::pg_ctl::ShutdownMode::Fast;
16use semver::Version;
17use sqlx::{PgPool, Row};
18use std::fs::{read_dir, remove_dir_all, remove_file};
19use std::io::prelude::*;
20use std::net::TcpListener;
21use std::path::PathBuf;
22use tracing::{debug, instrument};
23
24use crate::Error::{CreateDatabaseError, DatabaseExistsError, DropDatabaseError};
25
26const PGDATABASE: &str = "PGDATABASE";
27
/// Lifecycle state of a `PostgreSQL` server.
///
/// The enum is a plain, field-less value type, so it is `Copy` and supports full equality
/// (`Eq`) and hashing (`Hash`) in addition to `PartialEq`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Status {
    /// Archive not installed
    NotInstalled,
    /// Installation complete; not initialized
    Installed,
    /// Server started
    Started,
    /// Server initialized and stopped
    Stopped,
}
40
41/// `PostgreSQL` server
42#[derive(Clone, Debug)]
43pub struct PostgreSQL {
44    settings: Settings,
45}
46
47/// `PostgreSQL` server methods
48impl PostgreSQL {
49    /// Create a new [`PostgreSQL`] instance
50    #[must_use]
51    pub fn new(settings: Settings) -> Self {
52        let mut postgresql = PostgreSQL { settings };
53
54        // If an exact version is set, append the version to the installation directory to avoid
55        // conflicts with other versions.  This will also facilitate setting the status of the
56        // server to the correct initial value.  If the minor and release version are not set, the
57        // installation directory will be determined dynamically during the installation process.
58        if let Some(version) = postgresql.settings.version.exact_version() {
59            let path = &postgresql.settings.installation_dir;
60            let version_string = version.to_string();
61
62            if !path.ends_with(&version_string) {
63                postgresql.settings.installation_dir =
64                    postgresql.settings.installation_dir.join(version_string);
65            }
66        }
67
68        postgresql
69    }
70
71    /// Get the [status](Status) of the PostgreSQL server
72    #[instrument(level = "debug", skip(self))]
73    pub fn status(&self) -> Status {
74        if self.is_running() {
75            Status::Started
76        } else if self.is_initialized() {
77            Status::Stopped
78        } else if self.installed_dir().is_some() {
79            Status::Installed
80        } else {
81            Status::NotInstalled
82        }
83    }
84
85    /// Get the [settings](Settings) of the `PostgreSQL` server
86    #[must_use]
87    pub fn settings(&self) -> &Settings {
88        &self.settings
89    }
90
91    /// Find a directory where `PostgreSQL` server is installed.
92    /// This first checks if the installation directory exists and matches the version requirement.
93    /// If it doesn't, it will search all the child directories for the latest version that matches the requirement.
94    /// If it returns None, we couldn't find a matching installation.
95    fn installed_dir(&self) -> Option<PathBuf> {
96        let path = &self.settings.installation_dir;
97        let maybe_path_version = path
98            .file_name()
99            .and_then(|file_name| Version::parse(&file_name.to_string_lossy()).ok());
100        // If this directory matches the version requirement, we're done.
101        if let Some(path_version) = maybe_path_version {
102            if self.settings.version.matches(&path_version) && path.exists() {
103                return Some(path.clone());
104            }
105        }
106
107        // Get all directories in the path as versions.
108        let mut versions = read_dir(path)
109            .ok()?
110            .filter_map(|entry| {
111                let Some(entry) = entry.ok() else {
112                    // We ignore filesystem errors.
113                    return None;
114                };
115                // Skip non-directories
116                if !entry.file_type().ok()?.is_dir() {
117                    return None;
118                }
119                let file_name = entry.file_name();
120                let version = Version::parse(&file_name.to_string_lossy()).ok()?;
121                if self.settings.version.matches(&version) {
122                    Some((version, entry.path()))
123                } else {
124                    None
125                }
126            })
127            .collect::<Vec<_>>();
128        // Sort the versions in descending order i.e. latest version first
129        versions.sort_by(|(a, _), (b, _)| b.cmp(a));
130        // Get the first matching version as the best match
131        let version_path = versions.first().map(|(_, path)| path.clone());
132        version_path
133    }
134
135    /// Check if the `PostgreSQL` server is initialized
136    fn is_initialized(&self) -> bool {
137        self.settings.data_dir.join("postgresql.conf").exists()
138    }
139
140    /// Check if the `PostgreSQL` server is running
141    fn is_running(&self) -> bool {
142        let pid_file = self.settings.data_dir.join("postmaster.pid");
143        pid_file.exists()
144    }
145
146    /// Set up the database by extracting the archive and initializing the database.
147    /// If the installation directory already exists, the archive will not be extracted.
148    /// If the data directory already exists, the database will not be initialized.
149    #[instrument(skip(self))]
150    pub async fn setup(&mut self) -> Result<()> {
151        match self.installed_dir() {
152            Some(installed_dir) => {
153                self.settings.installation_dir = installed_dir;
154            }
155            None => {
156                self.install().await?;
157            }
158        }
159        if !self.is_initialized() {
160            self.initialize().await?;
161        }
162
163        Ok(())
164    }
165
166    /// Install the PostgreSQL server from the archive. If the version minor and/or release are not set,
167    /// the latest version will be determined dynamically during the installation process. If the archive
168    /// hash does not match the expected hash, an error will be returned. If the installation directory
169    /// already exists, the archive will not be extracted. If the archive is not found, an error will be
170    /// returned.
171    #[instrument(skip(self))]
172    async fn install(&mut self) -> Result<()> {
173        debug!(
174            "Starting installation process for version {}",
175            self.settings.version
176        );
177
178        // If the exact version is not set, determine the latest version and update the version and
179        // installation directory accordingly. This is an optimization to avoid downloading the
180        // archive if the latest version is already installed.
181        if self.settings.version.exact_version().is_none() {
182            let version = get_version(&self.settings.releases_url, &self.settings.version).await?;
183            self.settings.version = version.exact_version_req()?;
184            self.settings.installation_dir =
185                self.settings.installation_dir.join(version.to_string());
186        }
187
188        if self.settings.installation_dir.exists() {
189            debug!("Installation directory already exists");
190            return Ok(());
191        }
192
193        let url = &self.settings.releases_url;
194
195        #[cfg(feature = "bundled")]
196        // If the requested version is the same as the version of the bundled archive, use the bundled
197        // archive. This avoids downloading the archive in environments where internet access is
198        // restricted or undesirable.
199        let (version, bytes) = if *crate::settings::ARCHIVE_VERSION == self.settings.version {
200            debug!("Using bundled installation archive");
201            (
202                self.settings.version.clone(),
203                crate::settings::ARCHIVE.to_vec(),
204            )
205        } else {
206            let (version, bytes) = get_archive(url, &self.settings.version).await?;
207            (version.exact_version_req()?, bytes)
208        };
209
210        #[cfg(not(feature = "bundled"))]
211        let (version, bytes) = {
212            let (version, bytes) = get_archive(url, &self.settings.version).await?;
213            (version.exact_version_req()?, bytes)
214        };
215
216        self.settings.version = version;
217        extract(url, &bytes, &self.settings.installation_dir).await?;
218
219        debug!(
220            "Installed PostgreSQL version {} to {}",
221            self.settings.version,
222            self.settings.installation_dir.to_string_lossy()
223        );
224
225        Ok(())
226    }
227
228    /// Initialize the database in the data directory. This will create the necessary files and
229    /// directories to start the database.
230    #[instrument(skip(self))]
231    async fn initialize(&mut self) -> Result<()> {
232        if !self.settings.password_file.exists() {
233            let mut file = std::fs::File::create(&self.settings.password_file)?;
234            file.write_all(self.settings.password.as_bytes())?;
235        }
236
237        debug!(
238            "Initializing database {}",
239            self.settings.data_dir.to_string_lossy()
240        );
241
242        let initdb = InitDbBuilder::from(&self.settings)
243            .pgdata(&self.settings.data_dir)
244            .username(BOOTSTRAP_SUPERUSER)
245            .auth("password")
246            .pwfile(&self.settings.password_file)
247            .encoding("UTF8");
248
249        match self.execute_command(initdb).await {
250            Ok((_stdout, _stderr)) => {
251                debug!(
252                    "Initialized database {}",
253                    self.settings.data_dir.to_string_lossy()
254                );
255                Ok(())
256            }
257            Err(error) => Err(DatabaseInitializationError(error.to_string())),
258        }
259    }
260
261    /// Start the database and wait for the startup to complete.
262    /// If the port is set to `0`, the database will be started on a random port.
263    #[instrument(skip(self))]
264    pub async fn start(&mut self) -> Result<()> {
265        if self.settings.port == 0 {
266            let listener = TcpListener::bind(("0.0.0.0", 0))?;
267            self.settings.port = listener.local_addr()?.port();
268        }
269
270        debug!(
271            "Starting database {} on port {}",
272            self.settings.data_dir.to_string_lossy(),
273            self.settings.port
274        );
275        let start_log = self.settings.data_dir.join("start.log");
276        let mut options = Vec::new();
277        options.push(format!("-F -p {}", self.settings.port));
278        for (key, value) in &self.settings.configuration {
279            options.push(format!("-c {key}={value}"));
280        }
281        let pg_ctl = PgCtlBuilder::from(&self.settings)
282            .env(PGDATABASE, "")
283            .mode(Start)
284            .pgdata(&self.settings.data_dir)
285            .log(start_log)
286            .options(options.as_slice())
287            .wait();
288
289        match self.execute_command(pg_ctl).await {
290            Ok((_stdout, _stderr)) => {
291                debug!(
292                    "Started database {} on port {}",
293                    self.settings.data_dir.to_string_lossy(),
294                    self.settings.port
295                );
296                Ok(())
297            }
298            Err(error) => Err(DatabaseStartError(error.to_string())),
299        }
300    }
301
302    /// Stop the database gracefully (smart mode) and wait for the shutdown to complete.
303    #[instrument(skip(self))]
304    pub async fn stop(&self) -> Result<()> {
305        debug!(
306            "Stopping database {}",
307            self.settings.data_dir.to_string_lossy()
308        );
309        let pg_ctl = PgCtlBuilder::from(&self.settings)
310            .mode(Stop)
311            .pgdata(&self.settings.data_dir)
312            .shutdown_mode(Fast)
313            .wait();
314
315        match self.execute_command(pg_ctl).await {
316            Ok((_stdout, _stderr)) => {
317                debug!(
318                    "Stopped database {}",
319                    self.settings.data_dir.to_string_lossy()
320                );
321                Ok(())
322            }
323            Err(error) => Err(DatabaseStopError(error.to_string())),
324        }
325    }
326
327    /// Get a connection pool to the bootstrap database.
328    async fn get_pool(&self) -> Result<PgPool> {
329        let mut settings = self.settings.clone();
330        settings.username = BOOTSTRAP_SUPERUSER.to_string();
331        let database_url = settings.url(BOOTSTRAP_DATABASE);
332        let pool = PgPool::connect(database_url.as_str()).await?;
333        Ok(pool)
334    }
335
336    /// Create a new database with the given name.
337    #[instrument(skip(self))]
338    pub async fn create_database<S>(&self, database_name: S) -> Result<()>
339    where
340        S: AsRef<str> + std::fmt::Debug,
341    {
342        let database_name = database_name.as_ref();
343        debug!(
344            "Creating database {database_name} for {host}:{port}",
345            host = self.settings.host,
346            port = self.settings.port
347        );
348        let pool = self.get_pool().await?;
349        sqlx::query(format!("CREATE DATABASE \"{database_name}\"").as_str())
350            .execute(&pool)
351            .await
352            .map_err(|error| CreateDatabaseError(error.to_string()))?;
353        pool.close().await;
354        debug!(
355            "Created database {database_name} for {host}:{port}",
356            host = self.settings.host,
357            port = self.settings.port
358        );
359        Ok(())
360    }
361
362    /// Check if a database with the given name exists.
363    #[instrument(skip(self))]
364    pub async fn database_exists<S>(&self, database_name: S) -> Result<bool>
365    where
366        S: AsRef<str> + std::fmt::Debug,
367    {
368        let database_name = database_name.as_ref();
369        debug!(
370            "Checking if database {database_name} exists for {host}:{port}",
371            host = self.settings.host,
372            port = self.settings.port
373        );
374        let pool = self.get_pool().await?;
375        let row = sqlx::query("SELECT COUNT(*) FROM pg_database WHERE datname = $1")
376            .bind(database_name.to_string())
377            .fetch_one(&pool)
378            .await
379            .map_err(|error| DatabaseExistsError(error.to_string()))?;
380        let count: i64 = row.get(0);
381        pool.close().await;
382
383        Ok(count == 1)
384    }
385
386    /// Drop a database with the given name.
387    #[instrument(skip(self))]
388    pub async fn drop_database<S>(&self, database_name: S) -> Result<()>
389    where
390        S: AsRef<str> + std::fmt::Debug,
391    {
392        let database_name = database_name.as_ref();
393        debug!(
394            "Dropping database {database_name} for {host}:{port}",
395            host = self.settings.host,
396            port = self.settings.port
397        );
398        let pool = self.get_pool().await?;
399        sqlx::query(format!("DROP DATABASE IF EXISTS \"{database_name}\"").as_str())
400            .execute(&pool)
401            .await
402            .map_err(|error| DropDatabaseError(error.to_string()))?;
403        pool.close().await;
404        debug!(
405            "Dropped database {database_name} for {host}:{port}",
406            host = self.settings.host,
407            port = self.settings.port
408        );
409        Ok(())
410    }
411
412    #[cfg(not(feature = "tokio"))]
413    /// Execute a command and return the stdout and stderr as strings.
414    #[instrument(level = "debug", skip(self, command_builder), fields(program = ?command_builder.get_program()))]
415    async fn execute_command<B: CommandBuilder>(
416        &self,
417        command_builder: B,
418    ) -> postgresql_commands::Result<(String, String)> {
419        let mut command = command_builder.build();
420        command.execute()
421    }
422
423    #[cfg(feature = "tokio")]
424    /// Execute a command and return the stdout and stderr as strings.
425    #[instrument(level = "debug", skip(self, command_builder), fields(program = ?command_builder.get_program()))]
426    async fn execute_command<B: CommandBuilder>(
427        &self,
428        command_builder: B,
429    ) -> postgresql_commands::Result<(String, String)> {
430        let mut command = command_builder.build_tokio();
431        command.execute(self.settings.timeout).await
432    }
433}
434
435/// Default `PostgreSQL` server
436impl Default for PostgreSQL {
437    fn default() -> Self {
438        Self::new(Settings::default())
439    }
440}
441
442/// Stop the `PostgreSQL` server and remove the data directory if it is marked as temporary.
443impl Drop for PostgreSQL {
444    fn drop(&mut self) {
445        if self.status() == Status::Started {
446            let mut pg_ctl = PgCtlBuilder::from(&self.settings)
447                .mode(Stop)
448                .pgdata(&self.settings.data_dir)
449                .shutdown_mode(Fast)
450                .wait()
451                .build();
452
453            let _ = pg_ctl.output();
454        }
455
456        if self.settings.temporary {
457            let _ = remove_dir_all(&self.settings.data_dir);
458            let _ = remove_file(&self.settings.password_file);
459        }
460    }
461}