// pg_embed/pg_access.rs
//! File-system access layer for cached PostgreSQL binaries and database clusters.
//!
//! [`PgAccess`] encapsulates all paths used by pg-embed (cache dir, database
//! dir, executable paths, password file) and provides the operations that act
//! on those paths: downloading, unpacking, writing the password file, and
//! cleaning up.
//!
//! The module-level static `ACQUIRED_PG_BINS` prevents concurrent downloads
//! of the same binaries when multiple [`crate::postgres::PgEmbed`] instances
//! start simultaneously.

use std::cell::Cell;
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::{Arc, LazyLock};

use tokio::io::AsyncWriteExt;
use tokio::sync::Mutex;

use crate::pg_enums::{OperationSystem, PgAcquisitionStatus};
use crate::pg_errors::{Error, Result};
use crate::pg_fetch::PgFetchSettings;
use crate::pg_types::PgCommandSync;
use crate::pg_unpack;
26
/// Guards concurrent binary downloads across multiple [`crate::postgres::PgEmbed`] instances.
///
/// The key is the cache directory path; the value tracks whether acquisition
/// is in progress or finished.  Protected by a [`Mutex`] to allow only one
/// download per unique cache path at a time.
static ACQUIRED_PG_BINS: LazyLock<Arc<Mutex<HashMap<PathBuf, PgAcquisitionStatus>>>> =
    LazyLock::new(|| Arc::new(Mutex::new(HashMap::with_capacity(5))));

/// Name of the pg-embed subdirectory created inside the OS cache directory.
const PG_EMBED_CACHE_DIR_NAME: &str = "pg-embed";
/// File written by `initdb` into the cluster dir; its presence marks an
/// already-initialised cluster.
const PG_VERSION_FILE_NAME: &str = "PG_VERSION";
37
/// Manages all file-system paths and I/O operations for a single pg-embed instance.
///
/// Created by [`PgAccess::new`], which also creates the required directory
/// structure.  All path fields are derived from the fetch settings and the
/// caller-supplied database directory.
///
/// # Cache layout
///
/// ```text
/// {cache_dir}/pg-embed/{os}/{arch}/{version}/
///   bin/pg_ctl
///   bin/initdb
///   {platform}-{version}.zip
/// ```
pub struct PgAccess {
    /// Root of the per-version binary cache.
    pub cache_dir: PathBuf,
    /// Directory that holds the PostgreSQL cluster data files.
    pub database_dir: PathBuf,
    /// Path to the `pg_ctl` executable inside the cache.
    pub pg_ctl_exe: PathBuf,
    /// Path to the `initdb` executable inside the cache.
    pub init_db_exe: PathBuf,
    /// Path to the password file used by `initdb`.
    pub pw_file_path: PathBuf,
    /// Path where the downloaded binaries archive (`{platform}-{version}.zip`)
    /// is written before unpacking.
    pub zip_file_path: PathBuf,
    /// `PG_VERSION` file inside the cluster directory; used to detect an
    /// already-initialised cluster.
    pg_version_file: PathBuf,
    /// Download settings used to reconstruct the cache path.
    fetch_settings: PgFetchSettings,
}
71
72impl PgAccess {
73    /// Creates a new [`PgAccess`] and ensures the required directories exist.
74    ///
75    /// Both the per-version binary cache directory and `database_dir` are
76    /// created with [`tokio::fs::create_dir_all`] if they do not already exist.
77    ///
78    /// # Arguments
79    ///
80    /// * `fetch_settings` — Determines the OS, architecture, and version used
81    ///   to construct the cache path.
82    /// * `database_dir` — Where the PostgreSQL cluster data files will live.
83    ///
84    /// # Errors
85    ///
86    /// Returns [`Error::InvalidPgUrl`] if the OS cache directory cannot be
87    /// resolved.
88    /// Returns [`Error::DirCreationError`] if either directory cannot be
89    /// created.
90    pub async fn new(
91        fetch_settings: &PgFetchSettings,
92        database_dir: &Path,
93    ) -> Result<Self> {
94        let cache_dir = Self::create_cache_dir_structure(fetch_settings).await?;
95        Self::create_db_dir_structure(database_dir).await?;
96        let platform = fetch_settings.platform();
97        let pg_ctl = cache_dir.join("bin/pg_ctl");
98        let init_db = cache_dir.join("bin/initdb");
99        let zip_file_path = cache_dir.join(format!("{}-{}.zip", platform, fetch_settings.version.0));
100        let mut pw_file = database_dir.to_path_buf();
101        pw_file.set_extension("pwfile");
102        let pg_version_file = database_dir.join(PG_VERSION_FILE_NAME);
103
104        Ok(PgAccess {
105            cache_dir,
106            database_dir: database_dir.to_path_buf(),
107            pg_ctl_exe: pg_ctl,
108            init_db_exe: init_db,
109            pw_file_path: pw_file,
110            zip_file_path,
111            pg_version_file,
112            fetch_settings: fetch_settings.clone(),
113        })
114    }
115
116    /// Creates the OS-specific cache directory tree for this OS/arch/version.
117    ///
118    /// # Errors
119    ///
120    /// Returns [`Error::InvalidPgUrl`] if the OS cache directory cannot be
121    /// resolved.
122    /// Returns [`Error::DirCreationError`] if the directory cannot be created.
123    async fn create_cache_dir_structure(fetch_settings: &PgFetchSettings) -> Result<PathBuf> {
124        let cache_dir = dirs::cache_dir().ok_or(Error::InvalidPgUrl)?;
125        let os_string = match fetch_settings.operating_system {
126            OperationSystem::Darwin | OperationSystem::Windows | OperationSystem::Linux => {
127                fetch_settings.operating_system.to_string()
128            }
129            OperationSystem::AlpineLinux => {
130                format!("arch_{}", fetch_settings.operating_system)
131            }
132        };
133        let pg_path = format!(
134            "{}/{}/{}/{}",
135            PG_EMBED_CACHE_DIR_NAME,
136            os_string,
137            fetch_settings.architecture,
138            fetch_settings.version.0
139        );
140        let mut cache_pg_embed = cache_dir;
141        cache_pg_embed.push(pg_path);
142        tokio::fs::create_dir_all(&cache_pg_embed)
143            .await
144            .map_err(|e| Error::DirCreationError(e.to_string()))?;
145        Ok(cache_pg_embed)
146    }
147
148    /// Creates the database cluster directory.
149    ///
150    /// # Errors
151    ///
152    /// Returns [`Error::DirCreationError`] if the directory cannot be created.
153    async fn create_db_dir_structure(db_dir: &Path) -> Result<()> {
154        tokio::fs::create_dir_all(db_dir)
155            .await
156            .map_err(|e| Error::DirCreationError(e.to_string()))
157    }
158
159    /// Downloads and unpacks the PostgreSQL binaries if they are not already cached.
160    ///
161    /// Acquires the `ACQUIRED_PG_BINS` lock for the duration.  If another
162    /// instance already cached the binaries (i.e. [`Self::pg_executables_cached`]
163    /// returns `true`), this method returns immediately without downloading.
164    ///
165    /// # Errors
166    ///
167    /// Returns [`Error::DirCreationError`] if directories cannot be created.
168    /// Returns [`Error::DownloadFailure`] or [`Error::ConversionFailure`] if
169    /// the HTTP download fails.
170    /// Returns [`Error::WriteFileError`] if the JAR cannot be written to disk.
171    /// Returns [`Error::UnpackFailure`] or [`Error::InvalidPgPackage`] if
172    /// extraction fails.
173    pub async fn maybe_acquire_postgres(&self) -> Result<()> {
174        let mut lock = ACQUIRED_PG_BINS.lock().await;
175
176        if self.pg_executables_cached().await? {
177            return Ok(());
178        }
179
180        lock.insert(self.cache_dir.clone(), PgAcquisitionStatus::InProgress);
181        self.fetch_settings
182            .fetch_postgres_to_file(&self.zip_file_path)
183            .await?;
184        log::debug!(
185            "Unpacking postgres binaries {} {}",
186            self.zip_file_path.display(),
187            self.cache_dir.display()
188        );
189        pg_unpack::unpack_postgres(&self.zip_file_path, &self.cache_dir).await?;
190
191        if let Some(status) = lock.get_mut(&self.cache_dir) {
192            *status = PgAcquisitionStatus::Finished;
193        }
194        Ok(())
195    }
196
197    /// Returns `true` if the `initdb` executable is present in the cache.
198    ///
199    /// # Errors
200    ///
201    /// Returns [`Error::ReadFileError`] if the filesystem existence check fails.
202    pub async fn pg_executables_cached(&self) -> Result<bool> {
203        Self::path_exists(self.init_db_exe.as_path()).await
204    }
205
206    /// Returns `true` if both the executables and the cluster version file exist.
207    ///
208    /// A `true` result indicates the cluster was previously initialised with
209    /// `initdb` and does not need to be re-initialised.
210    ///
211    /// # Errors
212    ///
213    /// Returns [`Error::ReadFileError`] if either filesystem check fails.
214    pub async fn db_files_exist(&self) -> Result<bool> {
215        Ok(self.pg_executables_cached().await?
216            && Self::path_exists(self.pg_version_file.as_path()).await?)
217    }
218
219    /// Returns `true` if the `PG_VERSION` file exists inside `db_dir`.
220    ///
221    /// Useful for confirming that a cluster directory is non-empty without
222    /// holding a [`PgAccess`] instance.
223    ///
224    /// # Arguments
225    ///
226    /// * `db_dir` — The cluster data directory to inspect.
227    ///
228    /// # Errors
229    ///
230    /// Returns [`Error::ReadFileError`] if the filesystem check fails.
231    pub async fn pg_version_file_exists(db_dir: &Path) -> Result<bool> {
232        let pg_version_file = db_dir.join(PG_VERSION_FILE_NAME);
233        Self::path_exists(&pg_version_file).await
234    }
235
236    /// Returns `true` if `file` exists on the filesystem.
237    ///
238    /// Uses [`tokio::fs::try_exists`] which returns `false` (not an error) for
239    /// permission-denied on the file itself; see its documentation for edge
240    /// cases.
241    ///
242    /// # Errors
243    ///
244    /// Returns [`Error::ReadFileError`] if the syscall itself fails (e.g.
245    /// the parent directory is inaccessible).
246    async fn path_exists(file: &Path) -> Result<bool> {
247        tokio::fs::try_exists(file)
248            .await
249            .map_err(|e| Error::ReadFileError(e.to_string()))
250    }
251
252    /// Returns the current acquisition status for this instance's cache directory.
253    pub async fn acquisition_status(&self) -> PgAcquisitionStatus {
254        let lock = ACQUIRED_PG_BINS.lock().await;
255        let acquisition_status = lock.get(&self.cache_dir);
256        match acquisition_status {
257            None => PgAcquisitionStatus::Undefined,
258            Some(status) => *status,
259        }
260    }
261
262    /// Removes the database cluster directory and the password file.
263    ///
264    /// Both removals are attempted even if the first one fails; the first
265    /// error encountered is returned.  Called synchronously from
266    /// [`crate::postgres::PgEmbed`]'s `Drop` implementation.
267    ///
268    /// # Errors
269    ///
270    /// Returns [`Error::PgCleanUpFailure`] if either removal fails.
271    pub fn clean(&self) -> Result<()> {
272        let dir_result = std::fs::remove_dir_all(&self.database_dir)
273            .map_err(|e| Error::PgCleanUpFailure(e.to_string()));
274        let file_result = std::fs::remove_file(&self.pw_file_path)
275            .map_err(|e| Error::PgCleanUpFailure(e.to_string()));
276        // Both operations run before returning the first error (if any)
277        dir_result.and(file_result)
278    }
279
280    /// Removes the entire `pg-embed` binary cache directory.
281    ///
282    /// Useful for freeing disk space or forcing a fresh download.  Errors
283    /// during removal are silently ignored (the function always returns `Ok`).
284    ///
285    /// # Errors
286    ///
287    /// Returns [`Error::ReadFileError`] if the OS cache directory cannot be
288    /// resolved.
289    pub async fn purge() -> Result<()> {
290        let mut cache_dir = dirs::cache_dir()
291            .ok_or_else(|| Error::ReadFileError("cache dir not found".into()))?;
292        cache_dir.push(PG_EMBED_CACHE_DIR_NAME);
293        let _ = tokio::fs::remove_dir_all(&cache_dir).await;
294        Ok(())
295    }
296
297    /// Removes `database_dir` and `pw_file` asynchronously.
298    ///
299    /// Unlike [`Self::clean`], this is an `async` free-standing helper and
300    /// stops on the first error.
301    ///
302    /// # Arguments
303    ///
304    /// * `database_dir` — The cluster data directory to remove.
305    /// * `pw_file` — The password file to remove.
306    ///
307    /// # Errors
308    ///
309    /// Returns [`Error::PgCleanUpFailure`] if either removal fails.
310    pub async fn clean_up(database_dir: PathBuf, pw_file: PathBuf) -> Result<()> {
311        tokio::fs::remove_dir_all(&database_dir)
312            .await
313            .map_err(|e| Error::PgCleanUpFailure(e.to_string()))?;
314
315        tokio::fs::remove_file(&pw_file)
316            .await
317            .map_err(|e| Error::PgCleanUpFailure(e.to_string()))
318    }
319
320    /// Writes `password` bytes to [`Self::pw_file_path`].
321    ///
322    /// `initdb` reads this file via `--pwfile` to set the superuser password
323    /// without exposing it on the command line.
324    ///
325    /// # Arguments
326    ///
327    /// * `password` — The password bytes to write (UTF-8 text is expected but
328    ///   not enforced).
329    ///
330    /// # Errors
331    ///
332    /// Returns [`Error::WriteFileError`] if the file cannot be created or the
333    /// write fails.
334    pub async fn create_password_file(&self, password: &[u8]) -> Result<()> {
335        let mut file = tokio::fs::File::create(self.pw_file_path.as_path())
336            .await
337            .map_err(|e| Error::WriteFileError(e.to_string()))?;
338        file.write_all(password)
339            .await
340            .map_err(|e| Error::WriteFileError(e.to_string()))
341    }
342
343    /// Installs a third-party extension into the binary cache.
344    ///
345    /// Copies files from `extension_dir` into the appropriate subdirectory of
346    /// [`Self::cache_dir`]:
347    ///
348    /// | Source extension | Destination |
349    /// |---|---|
350    /// | `.so`, `.dylib`, `.dll` | `{cache_dir}/lib/` |
351    /// | `.control`, `.sql` | `{cache_dir}/share/postgresql/extension/` (or equivalent) |
352    /// | anything else, subdirectories | silently skipped |
353    ///
354    /// Call this method after [`crate::postgres::PgEmbed::setup`] and before
355    /// [`crate::postgres::PgEmbed::start_db`], then run
356    /// `CREATE EXTENSION IF NOT EXISTS <name>` once the server is up.
357    ///
358    /// # Arguments
359    ///
360    /// * `extension_dir` — Directory containing the extension files to install.
361    ///
362    /// # Errors
363    ///
364    /// Returns [`Error::DirCreationError`] if the target directories cannot be
365    /// created.
366    /// Returns [`Error::ReadFileError`] if `extension_dir` cannot be read or a
367    /// directory entry cannot be inspected.
368    /// Returns [`Error::WriteFileError`] if a file cannot be copied.
369    /// Returns the path of the `extension/` directory inside the binary cache.
370    ///
371    /// Searches for an existing `extension/` subdirectory under `share/` in the
372    /// cache (trying common PostgreSQL layout variants).  Falls back to
373    /// `share/postgresql/extension` — the standard location used by the
374    /// zonkyio binaries — when none of the candidates exist yet.
375    async fn share_extension_dir(cache_dir: &Path) -> PathBuf {
376        let candidates = [
377            cache_dir.join("share/postgresql/extension"),
378            cache_dir.join("share/extension"),
379        ];
380        for candidate in &candidates {
381            if tokio::fs::try_exists(candidate).await.unwrap_or(false) {
382                return candidate.clone();
383            }
384        }
385        candidates[0].clone()
386    }
387
388    pub async fn install_extension(&self, extension_dir: &Path) -> Result<()> {
389        let lib_dir = self.cache_dir.join("lib");
390        let share_ext_dir = Self::share_extension_dir(&self.cache_dir).await;
391
392        tokio::fs::create_dir_all(&lib_dir)
393            .await
394            .map_err(|e| Error::DirCreationError(e.to_string()))?;
395        tokio::fs::create_dir_all(&share_ext_dir)
396            .await
397            .map_err(|e| Error::DirCreationError(e.to_string()))?;
398
399        let mut entries = tokio::fs::read_dir(extension_dir)
400            .await
401            .map_err(|e| Error::ReadFileError(e.to_string()))?;
402
403        while let Some(entry) = entries
404            .next_entry()
405            .await
406            .map_err(|e| Error::ReadFileError(e.to_string()))?
407        {
408            let file_type = entry
409                .file_type()
410                .await
411                .map_err(|e| Error::ReadFileError(e.to_string()))?;
412            if !file_type.is_file() {
413                continue;
414            }
415
416            let path = entry.path();
417            let file_name = match path.file_name() {
418                Some(n) => n,
419                None => continue,
420            };
421            let dest_dir = match path.extension().and_then(|e| e.to_str()) {
422                Some("so") | Some("dylib") | Some("dll") => &lib_dir,
423                Some("control") | Some("sql") => &share_ext_dir,
424                _ => continue,
425            };
426            tokio::fs::copy(&path, dest_dir.join(file_name))
427                .await
428                .map_err(|e| Error::WriteFileError(e.to_string()))?;
429        }
430        Ok(())
431    }
432
433    /// Builds a synchronous `pg_ctl stop` [`std::process::Command`].
434    ///
435    /// Uses [`OsStr`][std::ffi::OsStr] arguments throughout to avoid UTF-8
436    /// conversion failures on platforms with non-Unicode paths.  The returned
437    /// [`PgCommandSync`] is ready to be spawned but has not yet been started.
438    ///
439    /// # Arguments
440    ///
441    /// * `database_dir` — Passed as the `-D` argument to `pg_ctl stop`.
442    pub fn stop_db_command_sync(&self, database_dir: &Path) -> PgCommandSync {
443        let mut command = Box::new(Cell::new(
444            std::process::Command::new(self.pg_ctl_exe.as_os_str()),
445        ));
446        command.get_mut().arg("stop").arg("-w").arg("-D").arg(database_dir);
447        command
448    }
449}
450
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pg_fetch::{PgFetchSettings, PG_V17};

    #[tokio::test]
    async fn test_install_extension() {
        // Stage a fake extension directory containing one file of each
        // category the installer routes (plus one it must skip).
        let ext_src = tempfile::TempDir::new().unwrap();
        let ext_path = ext_src.path();
        let fixtures: [(&str, &[u8]); 5] = [
            ("myvec.so", b"fake so"),
            ("myvec.dylib", b"fake dylib"),
            ("myvec.control", b"# control"),
            ("myvec--1.0.sql", b"-- sql"),
            ("README.txt", b"readme"),
        ];
        for (name, contents) in fixtures {
            std::fs::write(ext_path.join(name), contents).unwrap();
        }

        let cache_tmp = tempfile::TempDir::new().unwrap();
        let cache = cache_tmp.path().to_path_buf();

        let pg_access = PgAccess {
            cache_dir: cache.clone(),
            database_dir: cache.join("db"),
            pg_ctl_exe: cache.join("bin/pg_ctl"),
            init_db_exe: cache.join("bin/initdb"),
            pw_file_path: cache.join("db.pwfile"),
            zip_file_path: cache.join("pg.zip"),
            pg_version_file: cache.join("db/PG_VERSION"),
            fetch_settings: PgFetchSettings {
                version: PG_V17,
                ..Default::default()
            },
        };

        pg_access.install_extension(ext_path).await.unwrap();

        // Shared libraries land in lib/.
        assert!(cache.join("lib/myvec.so").exists(), "lib/myvec.so missing");
        assert!(cache.join("lib/myvec.dylib").exists(), "lib/myvec.dylib missing");
        // No existing share dir → falls back to share/postgresql/extension
        assert!(
            cache.join("share/postgresql/extension/myvec.control").exists(),
            "share/postgresql/extension/myvec.control missing"
        );
        assert!(
            cache.join("share/postgresql/extension/myvec--1.0.sql").exists(),
            "share/postgresql/extension/myvec--1.0.sql missing"
        );
        // Unrecognised extensions must not be copied anywhere.
        assert!(
            !cache.join("lib/README.txt").exists(),
            "README.txt should not be in lib/"
        );
        assert!(
            !cache.join("share/postgresql/extension/README.txt").exists(),
            "README.txt should not be in share/postgresql/extension/"
        );
    }
}