1use std::io::BufRead;
8use std::path::PathBuf;
9use std::process::Stdio;
10use std::sync::Arc;
11use std::time::Duration;
12
13use futures::TryFutureExt;
14use log::{error, info};
15use tokio::sync::Mutex;
16
17#[cfg(feature = "rt_tokio_migrate")]
18use sqlx_tokio::migrate::{MigrateDatabase, Migrator};
19#[cfg(feature = "rt_tokio_migrate")]
20use sqlx_tokio::postgres::PgPoolOptions;
21#[cfg(feature = "rt_tokio_migrate")]
22use sqlx_tokio::Postgres;
23
24use crate::command_executor::AsyncCommand;
25use crate::pg_access::PgAccess;
26use crate::pg_commands::PgCommand;
27use crate::pg_enums::{PgAuthMethod, PgServerStatus};
28use crate::pg_errors::{PgEmbedError, PgEmbedErrorType};
29use crate::pg_types::PgResult;
30use crate::{pg_fetch, pg_unpack};
31
/// Settings for an embedded PostgreSQL instance managed by [`PgEmbed`].
pub struct PgSettings {
    /// Database directory; passed to `PgAccess::new` and to the synchronous stop command
    /// (`stop_db_sync`).
    pub database_dir: PathBuf,
    /// Port passed to the start command and embedded in the connection URI.
    pub port: u16,
    /// Database user; embedded in the connection URI and passed to the init-db command.
    pub user: String,
    /// Password for `user`; embedded verbatim in the connection URI and written to the
    /// password file during `setup`.
    pub password: String,
    /// Authentication method passed to the init-db command.
    pub auth_method: PgAuthMethod,
    /// When `false`, `PgAccess::clean` is called when the `PgEmbed` is dropped.
    pub persistent: bool,
    /// Timeout handed to the executor of the init, start and stop commands.
    /// NOTE(review): presumably `None` means "no limit" — confirm in `AsyncCommand::execute`.
    pub timeout: Option<Duration>,
    /// Migration directory used by `PgEmbed::migrate`; `None` turns `migrate` into a no-op.
    pub migration_dir: Option<PathBuf>,
}
55
/// Handle to an embedded PostgreSQL installation and its server process.
///
/// Dropping the handle issues a blocking stop unless `shutting_down` is already set, and calls
/// `PgAccess::clean` unless `pg_settings.persistent` is `true`.
pub struct PgEmbed {
    /// Settings this instance was created with.
    pub pg_settings: PgSettings,
    /// Settings used to build `pg_access` and to fetch the binaries in `acquire_postgres`.
    pub fetch_settings: pg_fetch::PgFetchSettings,
    /// Connection URI without a database name:
    /// `postgres://{user}:{password}@localhost:{port}`. Use `full_db_uri` to append one.
    pub db_uri: String,
    /// Current server status, shared behind an async (tokio) mutex and updated by the
    /// async lifecycle methods.
    pub server_status: Arc<Mutex<PgServerStatus>>,
    /// Set by `stop_db`/`stop_db_sync` and cleared by `start_db`; while set, `Drop` does not
    /// issue another stop command.
    pub shutting_down: bool,
    /// Paths and executables of the local installation (e.g. `pg_ctl_exe`, `init_db_exe`,
    /// `pw_file_path`).
    pub pg_access: PgAccess,
}
76
77impl Drop for PgEmbed {
78 fn drop(&mut self) {
79 if !self.shutting_down {
80 let _ = self.stop_db_sync();
81 }
82 if !&self.pg_settings.persistent {
83 let _ = &self.pg_access.clean();
84 }
85 }
86}
87
88impl PgEmbed {
89 pub async fn new(
93 pg_settings: PgSettings,
94 fetch_settings: pg_fetch::PgFetchSettings,
95 ) -> PgResult<Self> {
96 let password: &str = &pg_settings.password;
97 let db_uri = format!(
98 "postgres://{}:{}@localhost:{}",
99 &pg_settings.user,
100 &password,
101 pg_settings.port.to_string()
102 );
103 let pg_access = PgAccess::new(&fetch_settings, &pg_settings.database_dir).await?;
104 Ok(PgEmbed {
105 pg_settings,
106 fetch_settings,
107 db_uri,
108 server_status: Arc::new(Mutex::new(PgServerStatus::Uninitialized)),
109 shutting_down: false,
110 pg_access,
111 })
112 }
113
114 pub async fn setup(&mut self) -> PgResult<()> {
120 if self.pg_access.acquisition_needed().await? {
121 self.acquire_postgres().await?;
122 }
123 self.pg_access
124 .create_password_file(self.pg_settings.password.as_bytes())
125 .await?;
126 if self.pg_access.db_files_exist().await? {
127 let mut server_status = self.server_status.lock().await;
128 *server_status = PgServerStatus::Initialized;
129 } else {
130 &self.init_db().await?;
131 }
132 Ok(())
133 }
134
135 pub async fn acquire_postgres(&self) -> PgResult<()> {
139 self.pg_access.mark_acquisition_in_progress().await?;
140 let pg_bin_data = &self.fetch_settings.fetch_postgres().await?;
141 self.pg_access.write_pg_zip(&pg_bin_data).await?;
142 pg_unpack::unpack_postgres(&self.pg_access.zip_file_path, &self.pg_access.cache_dir)
143 .await?;
144 self.pg_access.mark_acquisition_finished().await?;
145 Ok(())
146 }
147
148 pub async fn init_db(&mut self) -> PgResult<()> {
154 {
155 let mut server_status = self.server_status.lock().await;
156 *server_status = PgServerStatus::Initializing;
157 }
158
159 let mut executor = PgCommand::init_db_executor(
160 &self.pg_access.init_db_exe,
161 &self.pg_access.database_dir,
162 &self.pg_access.pw_file_path,
163 &self.pg_settings.user,
164 &self.pg_settings.auth_method,
165 )?;
166 let exit_status = executor.execute(self.pg_settings.timeout).await?;
167 let mut server_status = self.server_status.lock().await;
168 *server_status = exit_status;
169 Ok(())
170 }
171
172 pub async fn start_db(&mut self) -> PgResult<()> {
178 {
179 let mut server_status = self.server_status.lock().await;
180 *server_status = PgServerStatus::Starting;
181 }
182 self.shutting_down = false;
183 let mut executor = PgCommand::start_db_executor(
184 &self.pg_access.pg_ctl_exe,
185 &self.pg_access.database_dir,
186 &self.pg_settings.port,
187 )?;
188 let exit_status = executor.execute(self.pg_settings.timeout).await?;
189 let mut server_status = self.server_status.lock().await;
190 *server_status = exit_status;
191 Ok(())
192 }
193
194 pub async fn stop_db(&mut self) -> PgResult<()> {
200 {
201 let mut server_status = self.server_status.lock().await;
202 *server_status = PgServerStatus::Stopping;
203 }
204 self.shutting_down = true;
205 let mut executor =
206 PgCommand::stop_db_executor(&self.pg_access.pg_ctl_exe, &self.pg_access.database_dir)?;
207 let exit_status = executor.execute(self.pg_settings.timeout).await?;
208 let mut server_status = self.server_status.lock().await;
209 *server_status = exit_status;
210 Ok(())
211 }
212
213 pub fn stop_db_sync(&mut self) -> PgResult<()> {
219 self.shutting_down = true;
220 let mut stop_db_command = self
221 .pg_access
222 .stop_db_command_sync(&self.pg_settings.database_dir);
223 let process = stop_db_command
224 .get_mut()
225 .stdout(Stdio::piped())
226 .stderr(Stdio::piped())
227 .spawn()
228 .map_err(|e| PgEmbedError {
229 error_type: PgEmbedErrorType::PgError,
230 source: Some(Box::new(e)),
231 message: None,
232 })?;
233
234 self.handle_process_io_sync(process)
235 }
236
237 pub fn handle_process_io_sync(&self, mut process: std::process::Child) -> PgResult<()> {
241 let reader_out = std::io::BufReader::new(process.stdout.take().unwrap()).lines();
242 let reader_err = std::io::BufReader::new(process.stderr.take().unwrap()).lines();
243 reader_out.for_each(|line| info!("{}", line.unwrap()));
244 reader_err.for_each(|line| error!("{}", line.unwrap()));
245 Ok(())
246 }
247
248 #[cfg(any(
252 feature = "rt_tokio_migrate",
253 feature = "rt_async_std_migrate",
254 feature = "rt_actix_migrate"
255 ))]
256 pub async fn create_database(&self, db_name: &str) -> PgResult<()> {
257 Postgres::create_database(&self.full_db_uri(db_name))
258 .map_err(|e| PgEmbedError {
259 error_type: PgEmbedErrorType::PgTaskJoinError,
260 source: Some(Box::new(e)),
261 message: None,
262 })
263 .await?;
264 Ok(())
265 }
266
267 #[cfg(any(
271 feature = "rt_tokio_migrate",
272 feature = "rt_async_std_migrate",
273 feature = "rt_actix_migrate"
274 ))]
275 pub async fn drop_database(&self, db_name: &str) -> PgResult<()> {
276 Postgres::drop_database(&self.full_db_uri(db_name))
277 .map_err(|e| PgEmbedError {
278 error_type: PgEmbedErrorType::PgTaskJoinError,
279 source: Some(Box::new(e)),
280 message: None,
281 })
282 .await?;
283 Ok(())
284 }
285
286 #[cfg(any(
290 feature = "rt_tokio_migrate",
291 feature = "rt_async_std_migrate",
292 feature = "rt_actix_migrate"
293 ))]
294 pub async fn database_exists(&self, db_name: &str) -> PgResult<bool> {
295 let result = Postgres::database_exists(&self.full_db_uri(db_name))
296 .map_err(|e| PgEmbedError {
297 error_type: PgEmbedErrorType::PgTaskJoinError,
298 source: Some(Box::new(e)),
299 message: None,
300 })
301 .await?;
302 Ok(result)
303 }
304
305 pub fn full_db_uri(&self, db_name: &str) -> String {
311 format!("{}/{}", &self.db_uri, db_name)
312 }
313
314 #[cfg(any(
318 feature = "rt_tokio_migrate",
319 feature = "rt_async_std_migrate",
320 feature = "rt_actix_migrate"
321 ))]
322 pub async fn migrate(&self, db_name: &str) -> PgResult<()> {
323 if let Some(migration_dir) = &self.pg_settings.migration_dir {
324 let m = Migrator::new(migration_dir.as_path())
325 .map_err(|e| PgEmbedError {
326 error_type: PgEmbedErrorType::MigrationError,
327 source: Some(Box::new(e)),
328 message: None,
329 })
330 .await?;
331 let pool = PgPoolOptions::new()
332 .connect(&self.full_db_uri(db_name))
333 .map_err(|e| PgEmbedError {
334 error_type: PgEmbedErrorType::SqlQueryError,
335 source: Some(Box::new(e)),
336 message: None,
337 })
338 .await?;
339 m.run(&pool)
340 .map_err(|e| PgEmbedError {
341 error_type: PgEmbedErrorType::MigrationError,
342 source: Some(Box::new(e)),
343 message: None,
344 })
345 .await?;
346 }
347 Ok(())
348 }
349}