pg_embed/
postgres.rs

//!
//! PostgreSQL server
//!
//! Start, stop and initialize the PostgreSQL server, and create database
//! clusters and databases.
//!
7use std::io::BufRead;
8use std::path::PathBuf;
9use std::process::Stdio;
10use std::sync::Arc;
11use std::time::Duration;
12
13use futures::TryFutureExt;
14use log::{error, info};
15use tokio::sync::Mutex;
16
17#[cfg(feature = "rt_tokio_migrate")]
18use sqlx_tokio::migrate::{MigrateDatabase, Migrator};
19#[cfg(feature = "rt_tokio_migrate")]
20use sqlx_tokio::postgres::PgPoolOptions;
21#[cfg(feature = "rt_tokio_migrate")]
22use sqlx_tokio::Postgres;
23
24use crate::command_executor::AsyncCommand;
25use crate::pg_access::PgAccess;
26use crate::pg_commands::PgCommand;
27use crate::pg_enums::{PgAuthMethod, PgServerStatus};
28use crate::pg_errors::{PgEmbedError, PgEmbedErrorType};
29use crate::pg_types::PgResult;
30use crate::{pg_fetch, pg_unpack};
31
32///
33/// Database settings
34///
35pub struct PgSettings {
36    /// postgresql database directory
37    pub database_dir: PathBuf,
38    /// postgresql port
39    pub port: u16,
40    /// postgresql user name
41    pub user: String,
42    /// postgresql password
43    pub password: String,
44    /// authentication
45    pub auth_method: PgAuthMethod,
46    /// persist database
47    pub persistent: bool,
48    /// duration to wait before terminating process execution
49    /// pg_ctl start/stop and initdb timeout
50    pub timeout: Option<Duration>,
51    /// migrations folder
52    /// sql script files to execute on migrate
53    pub migration_dir: Option<PathBuf>,
54}
55
56///
57/// Embedded postgresql database
58///
59/// If the PgEmbed instance is dropped / goes out of scope and postgresql is still
60/// running, the postgresql process will be killed and depending on the [PgSettings::persistent] setting,
61/// file and directories will be cleaned up.
62///
63pub struct PgEmbed {
64    /// Postgresql settings
65    pub pg_settings: PgSettings,
66    /// Download settings
67    pub fetch_settings: pg_fetch::PgFetchSettings,
68    /// Database uri `postgres://{username}:{password}@localhost:{port}`
69    pub db_uri: String,
70    /// Postgres server status
71    pub server_status: Arc<Mutex<PgServerStatus>>,
72    pub shutting_down: bool,
73    /// Postgres files access
74    pub pg_access: PgAccess,
75}
76
77impl Drop for PgEmbed {
78    fn drop(&mut self) {
79        if !self.shutting_down {
80            let _ = self.stop_db_sync();
81        }
82        if !&self.pg_settings.persistent {
83            let _ = &self.pg_access.clean();
84        }
85    }
86}
87
88impl PgEmbed {
89    ///
90    /// Create a new PgEmbed instance
91    ///
92    pub async fn new(
93        pg_settings: PgSettings,
94        fetch_settings: pg_fetch::PgFetchSettings,
95    ) -> PgResult<Self> {
96        let password: &str = &pg_settings.password;
97        let db_uri = format!(
98            "postgres://{}:{}@localhost:{}",
99            &pg_settings.user,
100            &password,
101            pg_settings.port.to_string()
102        );
103        let pg_access = PgAccess::new(&fetch_settings, &pg_settings.database_dir).await?;
104        Ok(PgEmbed {
105            pg_settings,
106            fetch_settings,
107            db_uri,
108            server_status: Arc::new(Mutex::new(PgServerStatus::Uninitialized)),
109            shutting_down: false,
110            pg_access,
111        })
112    }
113
114    ///
115    /// Setup postgresql for execution
116    ///
117    /// Download, unpack, create password file and database
118    ///
119    pub async fn setup(&mut self) -> PgResult<()> {
120        if self.pg_access.acquisition_needed().await? {
121            self.acquire_postgres().await?;
122        }
123        self.pg_access
124            .create_password_file(self.pg_settings.password.as_bytes())
125            .await?;
126        if self.pg_access.db_files_exist().await? {
127            let mut server_status = self.server_status.lock().await;
128            *server_status = PgServerStatus::Initialized;
129        } else {
130            &self.init_db().await?;
131        }
132        Ok(())
133    }
134
135    ///
136    /// Download and unpack postgres binaries
137    ///
138    pub async fn acquire_postgres(&self) -> PgResult<()> {
139        self.pg_access.mark_acquisition_in_progress().await?;
140        let pg_bin_data = &self.fetch_settings.fetch_postgres().await?;
141        self.pg_access.write_pg_zip(&pg_bin_data).await?;
142        pg_unpack::unpack_postgres(&self.pg_access.zip_file_path, &self.pg_access.cache_dir)
143            .await?;
144        self.pg_access.mark_acquisition_finished().await?;
145        Ok(())
146    }
147
148    ///
149    /// Initialize postgresql database
150    ///
151    /// Returns `Ok(())` on success, otherwise returns an error.
152    ///
153    pub async fn init_db(&mut self) -> PgResult<()> {
154        {
155            let mut server_status = self.server_status.lock().await;
156            *server_status = PgServerStatus::Initializing;
157        }
158
159        let mut executor = PgCommand::init_db_executor(
160            &self.pg_access.init_db_exe,
161            &self.pg_access.database_dir,
162            &self.pg_access.pw_file_path,
163            &self.pg_settings.user,
164            &self.pg_settings.auth_method,
165        )?;
166        let exit_status = executor.execute(self.pg_settings.timeout).await?;
167        let mut server_status = self.server_status.lock().await;
168        *server_status = exit_status;
169        Ok(())
170    }
171
172    ///
173    /// Start postgresql database
174    ///
175    /// Returns `Ok(())` on success, otherwise returns an error.
176    ///
177    pub async fn start_db(&mut self) -> PgResult<()> {
178        {
179            let mut server_status = self.server_status.lock().await;
180            *server_status = PgServerStatus::Starting;
181        }
182        self.shutting_down = false;
183        let mut executor = PgCommand::start_db_executor(
184            &self.pg_access.pg_ctl_exe,
185            &self.pg_access.database_dir,
186            &self.pg_settings.port,
187        )?;
188        let exit_status = executor.execute(self.pg_settings.timeout).await?;
189        let mut server_status = self.server_status.lock().await;
190        *server_status = exit_status;
191        Ok(())
192    }
193
194    ///
195    /// Stop postgresql database
196    ///
197    /// Returns `Ok(())` on success, otherwise returns an error.
198    ///
199    pub async fn stop_db(&mut self) -> PgResult<()> {
200        {
201            let mut server_status = self.server_status.lock().await;
202            *server_status = PgServerStatus::Stopping;
203        }
204        self.shutting_down = true;
205        let mut executor =
206            PgCommand::stop_db_executor(&self.pg_access.pg_ctl_exe, &self.pg_access.database_dir)?;
207        let exit_status = executor.execute(self.pg_settings.timeout).await?;
208        let mut server_status = self.server_status.lock().await;
209        *server_status = exit_status;
210        Ok(())
211    }
212
213    ///
214    /// Stop postgresql database synchronous
215    ///
216    /// Returns `Ok(())` on success, otherwise returns an error.
217    ///
218    pub fn stop_db_sync(&mut self) -> PgResult<()> {
219        self.shutting_down = true;
220        let mut stop_db_command = self
221            .pg_access
222            .stop_db_command_sync(&self.pg_settings.database_dir);
223        let process = stop_db_command
224            .get_mut()
225            .stdout(Stdio::piped())
226            .stderr(Stdio::piped())
227            .spawn()
228            .map_err(|e| PgEmbedError {
229                error_type: PgEmbedErrorType::PgError,
230                source: Some(Box::new(e)),
231                message: None,
232            })?;
233
234        self.handle_process_io_sync(process)
235    }
236
237    ///
238    /// Handle process logging synchronous
239    ///
240    pub fn handle_process_io_sync(&self, mut process: std::process::Child) -> PgResult<()> {
241        let reader_out = std::io::BufReader::new(process.stdout.take().unwrap()).lines();
242        let reader_err = std::io::BufReader::new(process.stderr.take().unwrap()).lines();
243        reader_out.for_each(|line| info!("{}", line.unwrap()));
244        reader_err.for_each(|line| error!("{}", line.unwrap()));
245        Ok(())
246    }
247
248    ///
249    /// Create a database
250    ///
251    #[cfg(any(
252        feature = "rt_tokio_migrate",
253        feature = "rt_async_std_migrate",
254        feature = "rt_actix_migrate"
255    ))]
256    pub async fn create_database(&self, db_name: &str) -> PgResult<()> {
257        Postgres::create_database(&self.full_db_uri(db_name))
258            .map_err(|e| PgEmbedError {
259                error_type: PgEmbedErrorType::PgTaskJoinError,
260                source: Some(Box::new(e)),
261                message: None,
262            })
263            .await?;
264        Ok(())
265    }
266
267    ///
268    /// Drop a database
269    ///
270    #[cfg(any(
271        feature = "rt_tokio_migrate",
272        feature = "rt_async_std_migrate",
273        feature = "rt_actix_migrate"
274    ))]
275    pub async fn drop_database(&self, db_name: &str) -> PgResult<()> {
276        Postgres::drop_database(&self.full_db_uri(db_name))
277            .map_err(|e| PgEmbedError {
278                error_type: PgEmbedErrorType::PgTaskJoinError,
279                source: Some(Box::new(e)),
280                message: None,
281            })
282            .await?;
283        Ok(())
284    }
285
286    ///
287    /// Check database existence
288    ///
289    #[cfg(any(
290        feature = "rt_tokio_migrate",
291        feature = "rt_async_std_migrate",
292        feature = "rt_actix_migrate"
293    ))]
294    pub async fn database_exists(&self, db_name: &str) -> PgResult<bool> {
295        let result = Postgres::database_exists(&self.full_db_uri(db_name))
296            .map_err(|e| PgEmbedError {
297                error_type: PgEmbedErrorType::PgTaskJoinError,
298                source: Some(Box::new(e)),
299                message: None,
300            })
301            .await?;
302        Ok(result)
303    }
304
305    ///
306    /// The full database uri
307    ///
308    /// (*postgres://{username}:{password}@localhost:{port}/{db_name}*)
309    ///
310    pub fn full_db_uri(&self, db_name: &str) -> String {
311        format!("{}/{}", &self.db_uri, db_name)
312    }
313
314    ///
315    /// Run migrations
316    ///
317    #[cfg(any(
318        feature = "rt_tokio_migrate",
319        feature = "rt_async_std_migrate",
320        feature = "rt_actix_migrate"
321    ))]
322    pub async fn migrate(&self, db_name: &str) -> PgResult<()> {
323        if let Some(migration_dir) = &self.pg_settings.migration_dir {
324            let m = Migrator::new(migration_dir.as_path())
325                .map_err(|e| PgEmbedError {
326                    error_type: PgEmbedErrorType::MigrationError,
327                    source: Some(Box::new(e)),
328                    message: None,
329                })
330                .await?;
331            let pool = PgPoolOptions::new()
332                .connect(&self.full_db_uri(db_name))
333                .map_err(|e| PgEmbedError {
334                    error_type: PgEmbedErrorType::SqlQueryError,
335                    source: Some(Box::new(e)),
336                    message: None,
337                })
338                .await?;
339            m.run(&pool)
340                .map_err(|e| PgEmbedError {
341                    error_type: PgEmbedErrorType::MigrationError,
342                    source: Some(Box::new(e)),
343                    message: None,
344                })
345                .await?;
346        }
347        Ok(())
348    }
349}