use std::sync::Arc;
use std::time::Duration;
use sqlx::postgres::{PgPool, PgPoolOptions};
use crate::error::{PersistenceError, PersistenceResult};
/// Default for the builder's acquire timeout (forwarded to sqlx's `acquire_timeout`).
const DEFAULT_ACQUIRE_TIMEOUT: Duration = Duration::from_secs(10);
/// Default for the builder's idle timeout: five minutes, expressed in seconds.
const DEFAULT_IDLE_TIMEOUT: Duration = Duration::from_secs(5 * 60);
/// Default cap on the number of pooled connections.
const DEFAULT_MAX_CONNECTIONS: u32 = 16;
/// Entry point for the Postgres persistence backend.
///
/// Wraps a reference-counted sqlx connection pool. Cloning this value only bumps the
/// reference count, so every clone keeps using the same underlying pool.
#[derive(Clone)]
pub struct PostgresPersistence {
    /// Pool shared with every handle (lock, checkpointer, store, session log) built from here.
    pool: Arc<PgPool>,
}
impl PostgresPersistence {
    /// Takes ownership of a freshly built sqlx pool and puts it behind an `Arc` for sharing.
    fn from_pool(pool: PgPool) -> Self {
        let shared = Arc::new(pool);
        Self { pool: shared }
    }

    /// Returns another handle to the shared pool (a refcount bump, not a new pool).
    fn shared_pool(&self) -> Arc<PgPool> {
        Arc::clone(&self.pool)
    }

    /// Borrows the underlying sqlx pool, e.g. for running ad-hoc queries.
    pub fn pool(&self) -> &PgPool {
        self.pool.as_ref()
    }

    /// Applies the migrations that `sqlx::migrate!` embeds from the `./migrations` directory.
    ///
    /// # Errors
    ///
    /// Returns `PersistenceError::Backend`, with the message prefixed by `migrate: `, if
    /// sqlx reports a failure while running the migrations.
    pub async fn migrate(&self) -> PersistenceResult<()> {
        let migrator = sqlx::migrate!("./migrations");
        migrator
            .run(self.pool())
            .await
            .map_err(|err| PersistenceError::Backend(format!("migrate: {err}")))
    }

    /// Builds a `PostgresLock` that shares this value's pool.
    pub fn lock(&self) -> super::PostgresLock {
        super::PostgresLock::new(self.shared_pool())
    }

    /// Builds a `PostgresCheckpointer` for state type `S` that shares this value's pool.
    pub fn checkpointer<S>(&self) -> super::PostgresCheckpointer<S>
    where
        S: Clone + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
    {
        super::PostgresCheckpointer::new(self.shared_pool())
    }

    /// Builds a `PostgresStore` for value type `V` that shares this value's pool.
    pub fn store<V>(&self) -> super::PostgresStore<V>
    where
        V: Clone + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
    {
        super::PostgresStore::new(self.shared_pool())
    }

    /// Builds a `PostgresSessionLog` that shares this value's pool.
    pub fn session_log(&self) -> super::PostgresSessionLog {
        super::PostgresSessionLog::new(self.shared_pool())
    }
}
/// Builder for [`PostgresPersistence`], obtained from [`PostgresPersistence::builder`].
///
/// The connection string is required; every other setting starts from a default.
#[must_use]
pub struct PostgresPersistenceBuilder {
    /// Connection string handed to sqlx; `None` until `with_connection_string` is called.
    url: Option<String>,
    /// Forwarded to sqlx's `max_connections` pool option.
    max_connections: u32,
    /// Forwarded to sqlx's `acquire_timeout` pool option.
    acquire_timeout: Duration,
    /// Forwarded (wrapped in `Some`) to sqlx's `idle_timeout` pool option.
    idle_timeout: Duration,
    /// Forwarded to sqlx's `test_before_acquire` pool option.
    test_before_acquire: bool,
}

// Written by hand instead of derived: a Postgres connection string may embed credentials
// (`postgres://user:password@host/db`), and a derived impl would print it verbatim in any
// `{:?}` output (logs, `dbg!`, panic messages). Only whether a URL is set is shown.
impl std::fmt::Debug for PostgresPersistenceBuilder {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("PostgresPersistenceBuilder")
            .field("url", &self.url.as_ref().map(|_| "<redacted>"))
            .field("max_connections", &self.max_connections)
            .field("acquire_timeout", &self.acquire_timeout)
            .field("idle_timeout", &self.idle_timeout)
            .field("test_before_acquire", &self.test_before_acquire)
            .finish()
    }
}
impl PostgresPersistence {
    /// Starts configuring a new [`PostgresPersistence`].
    ///
    /// The builder begins with `DEFAULT_MAX_CONNECTIONS`, `DEFAULT_ACQUIRE_TIMEOUT`,
    /// `DEFAULT_IDLE_TIMEOUT` and `test_before_acquire` enabled, and with no connection
    /// string; one must be supplied via `with_connection_string` before connecting.
    pub fn builder() -> PostgresPersistenceBuilder {
        PostgresPersistenceBuilder {
            test_before_acquire: true,
            idle_timeout: DEFAULT_IDLE_TIMEOUT,
            acquire_timeout: DEFAULT_ACQUIRE_TIMEOUT,
            max_connections: DEFAULT_MAX_CONNECTIONS,
            url: None,
        }
    }
}
impl PostgresPersistenceBuilder {
    /// Sets the connection string passed to sqlx when the pool is opened. Required.
    pub fn with_connection_string(mut self, url: impl Into<String>) -> Self {
        self.url = Some(url.into());
        self
    }

    /// Sets the maximum number of pooled connections (default `DEFAULT_MAX_CONNECTIONS`).
    ///
    /// Must be at least 1; [`Self::connect`] rejects 0 with a configuration error.
    pub const fn with_max_connections(mut self, n: u32) -> Self {
        self.max_connections = n;
        self
    }

    /// Sets sqlx's `acquire_timeout` for the pool (default `DEFAULT_ACQUIRE_TIMEOUT`).
    pub const fn with_acquire_timeout(mut self, timeout: Duration) -> Self {
        self.acquire_timeout = timeout;
        self
    }

    /// Sets sqlx's `idle_timeout` for the pool (default `DEFAULT_IDLE_TIMEOUT`).
    pub const fn with_idle_timeout(mut self, timeout: Duration) -> Self {
        self.idle_timeout = timeout;
        self
    }

    /// Enables or disables sqlx's `test_before_acquire` pool option (default `true`).
    pub const fn with_test_before_acquire(mut self, on: bool) -> Self {
        self.test_before_acquire = on;
        self
    }

    /// Opens the connection pool with the configured settings.
    ///
    /// # Errors
    ///
    /// Returns `PersistenceError::Config` in any of these cases:
    /// - no connection string was set;
    /// - the connection string is blank;
    /// - `max_connections` is 0.
    ///
    /// Returns `PersistenceError::Backend`, with the message prefixed by `connect: `, if sqlx
    /// fails to open the pool.
    pub async fn connect(self) -> PersistenceResult<PostgresPersistence> {
        let url = self
            .url
            .ok_or_else(|| PersistenceError::Config("connection_string is required".into()))?;
        // A blank string is as good as missing: report it as a configuration problem
        // instead of deferring to whatever error sqlx produces when parsing it.
        if url.trim().is_empty() {
            return Err(PersistenceError::Config(
                "connection_string must not be empty".into(),
            ));
        }
        // A pool capped at zero connections could never hand one out, so fail fast with a
        // clear configuration error rather than leaving it to surface inside sqlx.
        if self.max_connections == 0 {
            return Err(PersistenceError::Config(
                "max_connections must be at least 1".into(),
            ));
        }
        let pool = PgPoolOptions::new()
            .max_connections(self.max_connections)
            .acquire_timeout(self.acquire_timeout)
            .idle_timeout(Some(self.idle_timeout))
            .test_before_acquire(self.test_before_acquire)
            .connect(&url)
            .await
            .map_err(|e| PersistenceError::Backend(format!("connect: {e}")))?;
        Ok(PostgresPersistence::from_pool(pool))
    }

    /// Opens the pool via [`Self::connect`], then applies migrations via
    /// [`PostgresPersistence::migrate`].
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Self::connect`] or [`PostgresPersistence::migrate`].
    pub async fn connect_and_migrate(self) -> PersistenceResult<PostgresPersistence> {
        let p = self.connect().await?;
        p.migrate().await?;
        Ok(p)
    }
}