1use crate::error::Error::{DatabaseInitializationError, DatabaseStartError, DatabaseStopError};
2use crate::error::Result;
3use crate::settings::{BOOTSTRAP_DATABASE, BOOTSTRAP_SUPERUSER, Settings};
4use postgresql_archive::get_version;
5use postgresql_archive::{ExactVersion, ExactVersionReq};
6use postgresql_archive::{extract, get_archive};
7#[cfg(feature = "tokio")]
8use postgresql_commands::AsyncCommandExecutor;
9use postgresql_commands::CommandBuilder;
10#[cfg(not(feature = "tokio"))]
11use postgresql_commands::CommandExecutor;
12use postgresql_commands::initdb::InitDbBuilder;
13use postgresql_commands::pg_ctl::Mode::{Start, Stop};
14use postgresql_commands::pg_ctl::PgCtlBuilder;
15use postgresql_commands::pg_ctl::ShutdownMode::Fast;
16use semver::Version;
17use sqlx::{PgPool, Row};
18use std::fs::{read_dir, remove_dir_all, remove_file};
19use std::io::prelude::*;
20use std::net::TcpListener;
21use std::path::PathBuf;
22use tracing::{debug, instrument};
23
24use crate::Error::{CreateDatabaseError, DatabaseExistsError, DropDatabaseError};
25
/// Name of the `PGDATABASE` environment variable; `start` sets it to an empty string for the
/// `pg_ctl` invocation.
const PGDATABASE: &str = "PGDATABASE";
27
/// Lifecycle state of an embedded PostgreSQL instance, as reported by `PostgreSQL::status`.
///
/// `Eq` and `Hash` are derived (all variants are fieldless) so the state can be used as a
/// map/set key or compared in generic code requiring total equality.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Status {
    /// No installation directory matching the configured version was found.
    NotInstalled,
    /// A matching installation exists, but the data directory has not been initialized.
    Installed,
    /// The server is considered running (a `postmaster.pid` file exists in the data directory).
    Started,
    /// The data directory is initialized (`postgresql.conf` exists) but the server is not running.
    Stopped,
}
40
41#[derive(Clone, Debug)]
43pub struct PostgreSQL {
44 settings: Settings,
45}
46
47impl PostgreSQL {
49 #[must_use]
51 pub fn new(settings: Settings) -> Self {
52 let mut postgresql = PostgreSQL { settings };
53
54 if let Some(version) = postgresql.settings.version.exact_version() {
59 let path = &postgresql.settings.installation_dir;
60 let version_string = version.to_string();
61
62 if !path.ends_with(&version_string) {
63 postgresql.settings.installation_dir =
64 postgresql.settings.installation_dir.join(version_string);
65 }
66 }
67
68 postgresql
69 }
70
71 #[instrument(level = "debug", skip(self))]
73 pub fn status(&self) -> Status {
74 if self.is_running() {
75 Status::Started
76 } else if self.is_initialized() {
77 Status::Stopped
78 } else if self.installed_dir().is_some() {
79 Status::Installed
80 } else {
81 Status::NotInstalled
82 }
83 }
84
85 #[must_use]
87 pub fn settings(&self) -> &Settings {
88 &self.settings
89 }
90
91 fn installed_dir(&self) -> Option<PathBuf> {
96 let path = &self.settings.installation_dir;
97 let maybe_path_version = path
98 .file_name()
99 .and_then(|file_name| Version::parse(&file_name.to_string_lossy()).ok());
100 if let Some(path_version) = maybe_path_version {
102 if self.settings.version.matches(&path_version) && path.exists() {
103 return Some(path.clone());
104 }
105 }
106
107 let mut versions = read_dir(path)
109 .ok()?
110 .filter_map(|entry| {
111 let Some(entry) = entry.ok() else {
112 return None;
114 };
115 if !entry.file_type().ok()?.is_dir() {
117 return None;
118 }
119 let file_name = entry.file_name();
120 let version = Version::parse(&file_name.to_string_lossy()).ok()?;
121 if self.settings.version.matches(&version) {
122 Some((version, entry.path()))
123 } else {
124 None
125 }
126 })
127 .collect::<Vec<_>>();
128 versions.sort_by(|(a, _), (b, _)| b.cmp(a));
130 let version_path = versions.first().map(|(_, path)| path.clone());
132 version_path
133 }
134
135 fn is_initialized(&self) -> bool {
137 self.settings.data_dir.join("postgresql.conf").exists()
138 }
139
140 fn is_running(&self) -> bool {
142 let pid_file = self.settings.data_dir.join("postmaster.pid");
143 pid_file.exists()
144 }
145
146 #[instrument(skip(self))]
150 pub async fn setup(&mut self) -> Result<()> {
151 match self.installed_dir() {
152 Some(installed_dir) => {
153 self.settings.installation_dir = installed_dir;
154 }
155 None => {
156 self.install().await?;
157 }
158 }
159 if !self.is_initialized() {
160 self.initialize().await?;
161 }
162
163 Ok(())
164 }
165
166 #[instrument(skip(self))]
172 async fn install(&mut self) -> Result<()> {
173 debug!(
174 "Starting installation process for version {}",
175 self.settings.version
176 );
177
178 if self.settings.version.exact_version().is_none() {
182 let version = get_version(&self.settings.releases_url, &self.settings.version).await?;
183 self.settings.version = version.exact_version_req()?;
184 self.settings.installation_dir =
185 self.settings.installation_dir.join(version.to_string());
186 }
187
188 if self.settings.installation_dir.exists() {
189 debug!("Installation directory already exists");
190 return Ok(());
191 }
192
193 let url = &self.settings.releases_url;
194
195 #[cfg(feature = "bundled")]
196 let (version, bytes) = if *crate::settings::ARCHIVE_VERSION == self.settings.version {
200 debug!("Using bundled installation archive");
201 (
202 self.settings.version.clone(),
203 crate::settings::ARCHIVE.to_vec(),
204 )
205 } else {
206 let (version, bytes) = get_archive(url, &self.settings.version).await?;
207 (version.exact_version_req()?, bytes)
208 };
209
210 #[cfg(not(feature = "bundled"))]
211 let (version, bytes) = {
212 let (version, bytes) = get_archive(url, &self.settings.version).await?;
213 (version.exact_version_req()?, bytes)
214 };
215
216 self.settings.version = version;
217 extract(url, &bytes, &self.settings.installation_dir).await?;
218
219 debug!(
220 "Installed PostgreSQL version {} to {}",
221 self.settings.version,
222 self.settings.installation_dir.to_string_lossy()
223 );
224
225 Ok(())
226 }
227
228 #[instrument(skip(self))]
231 async fn initialize(&mut self) -> Result<()> {
232 if !self.settings.password_file.exists() {
233 let mut file = std::fs::File::create(&self.settings.password_file)?;
234 file.write_all(self.settings.password.as_bytes())?;
235 }
236
237 debug!(
238 "Initializing database {}",
239 self.settings.data_dir.to_string_lossy()
240 );
241
242 let initdb = InitDbBuilder::from(&self.settings)
243 .pgdata(&self.settings.data_dir)
244 .username(BOOTSTRAP_SUPERUSER)
245 .auth("password")
246 .pwfile(&self.settings.password_file)
247 .encoding("UTF8");
248
249 match self.execute_command(initdb).await {
250 Ok((_stdout, _stderr)) => {
251 debug!(
252 "Initialized database {}",
253 self.settings.data_dir.to_string_lossy()
254 );
255 Ok(())
256 }
257 Err(error) => Err(DatabaseInitializationError(error.to_string())),
258 }
259 }
260
261 #[instrument(skip(self))]
264 pub async fn start(&mut self) -> Result<()> {
265 if self.settings.port == 0 {
266 let listener = TcpListener::bind(("0.0.0.0", 0))?;
267 self.settings.port = listener.local_addr()?.port();
268 }
269
270 debug!(
271 "Starting database {} on port {}",
272 self.settings.data_dir.to_string_lossy(),
273 self.settings.port
274 );
275 let start_log = self.settings.data_dir.join("start.log");
276 let mut options = Vec::new();
277 options.push(format!("-F -p {}", self.settings.port));
278 for (key, value) in &self.settings.configuration {
279 options.push(format!("-c {key}={value}"));
280 }
281 let pg_ctl = PgCtlBuilder::from(&self.settings)
282 .env(PGDATABASE, "")
283 .mode(Start)
284 .pgdata(&self.settings.data_dir)
285 .log(start_log)
286 .options(options.as_slice())
287 .wait();
288
289 match self.execute_command(pg_ctl).await {
290 Ok((_stdout, _stderr)) => {
291 debug!(
292 "Started database {} on port {}",
293 self.settings.data_dir.to_string_lossy(),
294 self.settings.port
295 );
296 Ok(())
297 }
298 Err(error) => Err(DatabaseStartError(error.to_string())),
299 }
300 }
301
302 #[instrument(skip(self))]
304 pub async fn stop(&self) -> Result<()> {
305 debug!(
306 "Stopping database {}",
307 self.settings.data_dir.to_string_lossy()
308 );
309 let pg_ctl = PgCtlBuilder::from(&self.settings)
310 .mode(Stop)
311 .pgdata(&self.settings.data_dir)
312 .shutdown_mode(Fast)
313 .wait();
314
315 match self.execute_command(pg_ctl).await {
316 Ok((_stdout, _stderr)) => {
317 debug!(
318 "Stopped database {}",
319 self.settings.data_dir.to_string_lossy()
320 );
321 Ok(())
322 }
323 Err(error) => Err(DatabaseStopError(error.to_string())),
324 }
325 }
326
327 async fn get_pool(&self) -> Result<PgPool> {
329 let mut settings = self.settings.clone();
330 settings.username = BOOTSTRAP_SUPERUSER.to_string();
331 let database_url = settings.url(BOOTSTRAP_DATABASE);
332 let pool = PgPool::connect(database_url.as_str()).await?;
333 Ok(pool)
334 }
335
336 #[instrument(skip(self))]
338 pub async fn create_database<S>(&self, database_name: S) -> Result<()>
339 where
340 S: AsRef<str> + std::fmt::Debug,
341 {
342 let database_name = database_name.as_ref();
343 debug!(
344 "Creating database {database_name} for {host}:{port}",
345 host = self.settings.host,
346 port = self.settings.port
347 );
348 let pool = self.get_pool().await?;
349 sqlx::query(format!("CREATE DATABASE \"{database_name}\"").as_str())
350 .execute(&pool)
351 .await
352 .map_err(|error| CreateDatabaseError(error.to_string()))?;
353 pool.close().await;
354 debug!(
355 "Created database {database_name} for {host}:{port}",
356 host = self.settings.host,
357 port = self.settings.port
358 );
359 Ok(())
360 }
361
362 #[instrument(skip(self))]
364 pub async fn database_exists<S>(&self, database_name: S) -> Result<bool>
365 where
366 S: AsRef<str> + std::fmt::Debug,
367 {
368 let database_name = database_name.as_ref();
369 debug!(
370 "Checking if database {database_name} exists for {host}:{port}",
371 host = self.settings.host,
372 port = self.settings.port
373 );
374 let pool = self.get_pool().await?;
375 let row = sqlx::query("SELECT COUNT(*) FROM pg_database WHERE datname = $1")
376 .bind(database_name.to_string())
377 .fetch_one(&pool)
378 .await
379 .map_err(|error| DatabaseExistsError(error.to_string()))?;
380 let count: i64 = row.get(0);
381 pool.close().await;
382
383 Ok(count == 1)
384 }
385
386 #[instrument(skip(self))]
388 pub async fn drop_database<S>(&self, database_name: S) -> Result<()>
389 where
390 S: AsRef<str> + std::fmt::Debug,
391 {
392 let database_name = database_name.as_ref();
393 debug!(
394 "Dropping database {database_name} for {host}:{port}",
395 host = self.settings.host,
396 port = self.settings.port
397 );
398 let pool = self.get_pool().await?;
399 sqlx::query(format!("DROP DATABASE IF EXISTS \"{database_name}\"").as_str())
400 .execute(&pool)
401 .await
402 .map_err(|error| DropDatabaseError(error.to_string()))?;
403 pool.close().await;
404 debug!(
405 "Dropped database {database_name} for {host}:{port}",
406 host = self.settings.host,
407 port = self.settings.port
408 );
409 Ok(())
410 }
411
412 #[cfg(not(feature = "tokio"))]
413 #[instrument(level = "debug", skip(self, command_builder), fields(program = ?command_builder.get_program()))]
415 async fn execute_command<B: CommandBuilder>(
416 &self,
417 command_builder: B,
418 ) -> postgresql_commands::Result<(String, String)> {
419 let mut command = command_builder.build();
420 command.execute()
421 }
422
423 #[cfg(feature = "tokio")]
424 #[instrument(level = "debug", skip(self, command_builder), fields(program = ?command_builder.get_program()))]
426 async fn execute_command<B: CommandBuilder>(
427 &self,
428 command_builder: B,
429 ) -> postgresql_commands::Result<(String, String)> {
430 let mut command = command_builder.build_tokio();
431 command.execute(self.settings.timeout).await
432 }
433}
434
435impl Default for PostgreSQL {
437 fn default() -> Self {
438 Self::new(Settings::default())
439 }
440}
441
442impl Drop for PostgreSQL {
444 fn drop(&mut self) {
445 if self.status() == Status::Started {
446 let mut pg_ctl = PgCtlBuilder::from(&self.settings)
447 .mode(Stop)
448 .pgdata(&self.settings.data_dir)
449 .shutdown_mode(Fast)
450 .wait()
451 .build();
452
453 let _ = pg_ctl.output();
454 }
455
456 if self.settings.temporary {
457 let _ = remove_dir_all(&self.settings.data_dir);
458 let _ = remove_file(&self.settings.password_file);
459 }
460 }
461}