use crate::{
SqlColdBackend, SqlColdError,
backend::{DEFAULT_READ_TIMEOUT, DEFAULT_WRITE_TIMEOUT},
};
use signet_cold::ColdConnect;
use sqlx::pool::PoolOptions;
use std::time::Duration;
/// Errors that can arise while setting up the SQL cold-storage connector.
// NOTE(review): `ColdInit` both embeds the inner error in its Display text
// (`{0}`) and, via `#[from]`, exposes it as `source()`. Reporters that walk
// the error chain may therefore print the inner message twice.
#[derive(Debug, thiserror::Error)]
pub enum SqlConnectorError {
    /// The named environment variable could not be used as a URL source
    /// (for example, it is unset). Returned by `SqlConnector::from_env`;
    /// carries the variable's name.
    #[error("missing environment variable: {0}")]
    MissingEnvVar(&'static str),
    /// Wraps a [`SqlColdError`] raised while initializing cold storage.
    /// A `SqlColdError` converts into this variant via the `From` impl
    /// generated by `#[from]`, so `?` can be used directly.
    #[error("cold storage initialization failed: {0}")]
    ColdInit(#[from] SqlColdError),
}
/// Connection configuration used to build a [`SqlColdBackend`].
///
/// Holds the database URL, `sqlx` pool options for the [`sqlx::Any`] driver,
/// and the read/write timeouts applied to the backend once it has connected.
/// Construct with `SqlConnector::new` or `SqlConnector::from_env`, then refine
/// with the `with_*` builder methods.
#[cfg(any(feature = "sqlite", feature = "postgres"))]
#[derive(Debug, Clone)]
pub struct SqlConnector {
    /// Database URL passed to [`SqlColdBackend::connect_with`].
    url: String,
    /// Pool settings passed to [`SqlColdBackend::connect_with`].
    pool_opts: PoolOptions<sqlx::Any>,
    /// Read timeout applied to the backend after connecting.
    read_timeout: Duration,
    /// Write timeout applied to the backend after connecting.
    write_timeout: Duration,
}
#[cfg(any(feature = "sqlite", feature = "postgres"))]
impl SqlConnector {
    /// Creates a connector for the database at `url`.
    ///
    /// Pool settings start from [`PoolOptions::new`], and the read/write
    /// timeouts start at [`DEFAULT_READ_TIMEOUT`] and [`DEFAULT_WRITE_TIMEOUT`].
    #[must_use]
    pub fn new(url: impl Into<String>) -> Self {
        Self {
            url: url.into(),
            pool_opts: PoolOptions::new(),
            read_timeout: DEFAULT_READ_TIMEOUT,
            write_timeout: DEFAULT_WRITE_TIMEOUT,
        }
    }

    /// Returns the database URL this connector targets.
    #[must_use]
    pub fn url(&self) -> &str {
        &self.url
    }

    /// Replaces the pool options wholesale.
    ///
    /// Pool settings applied by earlier builder calls (for example
    /// [`Self::with_max_connections`]) are discarded. When combining this
    /// with the finer-grained setters, call it first.
    #[must_use]
    pub fn with_pool_options(mut self, pool_opts: PoolOptions<sqlx::Any>) -> Self {
        self.pool_opts = pool_opts;
        self
    }

    /// Sets the maximum number of connections the pool may hold
    /// (forwarded to [`PoolOptions::max_connections`]).
    #[must_use]
    pub fn with_max_connections(mut self, n: u32) -> Self {
        self.pool_opts = self.pool_opts.max_connections(n);
        self
    }

    /// Sets the minimum number of connections the pool maintains
    /// (forwarded to [`PoolOptions::min_connections`]).
    #[must_use]
    pub fn with_min_connections(mut self, n: u32) -> Self {
        self.pool_opts = self.pool_opts.min_connections(n);
        self
    }

    /// Sets how long to wait when acquiring a connection from the pool
    /// (forwarded to [`PoolOptions::acquire_timeout`]).
    #[must_use]
    pub fn with_acquire_timeout(mut self, timeout: Duration) -> Self {
        self.pool_opts = self.pool_opts.acquire_timeout(timeout);
        self
    }

    /// Sets the maximum lifetime of a pooled connection
    /// (forwarded to [`PoolOptions::max_lifetime`]). Per `sqlx`, `None`
    /// removes the limit.
    #[must_use]
    pub fn with_max_lifetime(mut self, lifetime: Option<Duration>) -> Self {
        self.pool_opts = self.pool_opts.max_lifetime(lifetime);
        self
    }

    /// Sets how long a pooled connection may sit idle before being closed
    /// (forwarded to [`PoolOptions::idle_timeout`]). Per `sqlx`, `None`
    /// removes the limit.
    #[must_use]
    pub fn with_idle_timeout(mut self, timeout: Option<Duration>) -> Self {
        self.pool_opts = self.pool_opts.idle_timeout(timeout);
        self
    }

    /// Sets the read timeout applied to the backend after it connects.
    ///
    /// # Panics
    ///
    /// Panics if `d` is shorter than one millisecond.
    #[must_use]
    pub fn with_read_timeout(mut self, d: Duration) -> Self {
        assert!(d.as_millis() >= 1, "read_timeout must be >= 1ms (got {d:?})");
        self.read_timeout = d;
        self
    }

    /// Sets the write timeout applied to the backend after it connects.
    ///
    /// # Panics
    ///
    /// Panics if `d` is shorter than one millisecond.
    #[must_use]
    pub fn with_write_timeout(mut self, d: Duration) -> Self {
        assert!(d.as_millis() >= 1, "write_timeout must be >= 1ms (got {d:?})");
        self.write_timeout = d;
        self
    }

    /// Creates a connector whose URL is read from the environment variable
    /// `env_var`, with all other settings at their defaults (see [`Self::new`]).
    ///
    /// # Errors
    ///
    /// Returns [`SqlConnectorError::MissingEnvVar`] in any of these cases:
    /// - the variable is unset;
    /// - its value is not valid Unicode;
    /// - its value is empty or whitespace-only.
    pub fn from_env(env_var: &'static str) -> Result<Self, SqlConnectorError> {
        match std::env::var(env_var) {
            Ok(url) if !url.trim().is_empty() => Ok(Self::new(url)),
            // A blank value can never be a usable URL. Report it like an unset
            // variable, naming the variable, rather than deferring a confusing
            // failure until `connect`.
            _ => Err(SqlConnectorError::MissingEnvVar(env_var)),
        }
    }
}
#[cfg(any(feature = "sqlite", feature = "postgres"))]
impl ColdConnect for SqlConnector {
    type Cold = SqlColdBackend;
    type Error = SqlColdError;

    /// Opens a [`SqlColdBackend`] using this connector's URL and pool options,
    /// then applies the configured read and write timeouts to it.
    ///
    /// The returned future owns copies of every setting it needs, so it does
    /// not borrow `self`. Errors from [`SqlColdBackend::connect_with`] are
    /// propagated unchanged.
    fn connect(&self) -> impl std::future::Future<Output = Result<Self::Cold, Self::Error>> + Send {
        // Snapshot the whole configuration up front so the future is self-contained.
        let Self { url, pool_opts, read_timeout, write_timeout } = self.clone();

        async move {
            let connected = SqlColdBackend::connect_with(&url, pool_opts).await?;
            let tuned = connected
                .with_read_timeout(read_timeout)
                .with_write_timeout(write_timeout);
            Ok(tuned)
        }
    }
}