//! Connection-pool builders and providers backed by `sqlx`.
//!
//! The SQLite types are compiled with the `sqlite` feature and the PostgreSQL
//! types with the `postgres` feature. The `nidus-config`, `observability` and
//! `health` features enable additional, optional integrations on those types.
#![deny(missing_docs)]
use nidus_core::NidusError;
use thiserror::Error;
/// Result alias used throughout this crate; the error type is always [`SqlxError`].
pub type Result<T> = std::result::Result<T, SqlxError>;
#[derive(Debug, Error)]
pub enum SqlxError {
#[error(transparent)]
Sqlx(#[from] sqlx::Error),
#[error(transparent)]
Nidus(#[from] NidusError),
#[cfg(feature = "nidus-config")]
#[error(transparent)]
Config(#[from] nidus_config::ConfigError),
}
#[cfg(feature = "sqlite")]
mod sqlite {
#[cfg(feature = "observability")]
use std::time::Instant;
use super::Result;
use nidus_core::Container;
/// Settings used to open a SQLite connection pool.
///
/// Holds the database URL and an optional cap on the number of pooled
/// connections; `None` means no cap is passed to the pool options.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SqlitePoolConfig {
    // URL handed to the pool when connecting.
    database_url: String,
    // Optional upper bound on pooled connections.
    max_connections: Option<u32>,
}
impl SqlitePoolConfig {
    /// Creates a configuration for `database_url` with no connection limit set.
    pub fn new(database_url: impl Into<String>) -> Self {
        let database_url: String = database_url.into();
        Self {
            database_url,
            max_connections: None,
        }
    }

    /// Returns the configuration with the connection limit set to `max_connections`.
    pub fn with_max_connections(self, max_connections: u32) -> Self {
        Self {
            max_connections: Some(max_connections),
            ..self
        }
    }

    /// Returns the configured database URL.
    pub fn database_url(&self) -> &str {
        self.database_url.as_str()
    }

    /// Returns the configured connection limit, if one was set.
    pub fn max_connections(&self) -> Option<u32> {
        self.max_connections
    }

    /// Builds a configuration from the value stored at `path` in `config`.
    ///
    /// The value is deserialized into a `url` string and an optional
    /// `max_connections` integer.
    ///
    /// # Errors
    ///
    /// Propagates the error returned by `Config::get_required_path_typed`.
    #[cfg(feature = "nidus-config")]
    pub fn from_config_path<I, S>(config: &nidus_config::Config, path: I) -> Result<Self>
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        #[derive(serde::Deserialize)]
        struct RawConfig {
            url: String,
            max_connections: Option<u32>,
        }

        let raw: RawConfig = config.get_required_path_typed(path)?;
        // An absent limit stays `None`, a present one is stored as-is.
        Ok(Self {
            database_url: raw.url,
            max_connections: raw.max_connections,
        })
    }
}
/// Builder that accumulates SQLite pool settings before connecting.
///
/// With the `observability` feature enabled it can also carry an observer
/// that is handed on to the provider it creates.
#[derive(Debug, Clone)]
pub struct SqlitePoolBuilder {
    // Settings applied when the pool is opened.
    config: SqlitePoolConfig,
    // Optional observer; `None` until one is supplied.
    #[cfg(feature = "observability")]
    observer: Option<nidus_observability::ObservabilityAdapterObserver>,
}
impl SqlitePoolBuilder {
    /// Creates a builder that targets `sqlite::memory:` with no connection limit.
    pub fn new() -> Self {
        let config = SqlitePoolConfig::new("sqlite::memory:");
        Self {
            config,
            #[cfg(feature = "observability")]
            observer: None,
        }
    }

    /// Replaces the builder's entire configuration with `config`.
    pub fn config(mut self, config: SqlitePoolConfig) -> Self {
        self.config = config;
        self
    }

    /// Sets the database URL, leaving the other settings untouched.
    pub fn database_url(mut self, database_url: impl Into<String>) -> Self {
        let url: String = database_url.into();
        self.config.database_url = url;
        self
    }

    /// Sets the maximum number of pooled connections.
    pub fn max_connections(mut self, max_connections: u32) -> Self {
        self.config = self.config.with_max_connections(max_connections);
        self
    }

    /// Attaches an observer that receives the `connect` operation and is
    /// passed on to the resulting provider.
    #[cfg(feature = "observability")]
    pub fn observability(
        self,
        observer: nidus_observability::ObservabilityAdapterObserver,
    ) -> Self {
        Self {
            observer: Some(observer),
            ..self
        }
    }

    /// Opens the pool described by this builder.
    ///
    /// With the `observability` feature, the attempt is reported to the
    /// observer as a `connect` operation whether it succeeds or fails.
    ///
    /// # Errors
    ///
    /// Returns the `sqlx` error if the pool cannot be opened.
    pub async fn connect(self) -> Result<SqlitePoolProvider> {
        #[cfg(feature = "observability")]
        let observer = self.observer;

        // Only pass a limit to sqlx when one was configured.
        let options = sqlx::sqlite::SqlitePoolOptions::new();
        let options = match self.config.max_connections {
            Some(limit) => options.max_connections(limit),
            None => options,
        };

        #[cfg(feature = "observability")]
        let started_at = Instant::now();
        let attempt = options.connect(&self.config.database_url).await;

        // Report before `?` so failed attempts are recorded as well.
        #[cfg(feature = "observability")]
        record_adapter_operation(
            &observer,
            "connect",
            nidus_observability::OperationStatus::from(attempt.is_ok()),
            started_at,
        );

        let pool = attempt?;
        let provider = SqlitePoolProvider {
            pool,
            #[cfg(feature = "observability")]
            observer,
        };
        Ok(provider)
    }

    /// Connects and registers the resulting provider in `container` as a singleton.
    ///
    /// # Errors
    ///
    /// Returns an error if connecting fails or if the container rejects the
    /// registration.
    pub async fn register(self, container: &mut Container) -> Result<()> {
        let provider = self.connect().await?;
        container.register_singleton(provider)?;
        Ok(())
    }
}
impl Default for SqlitePoolBuilder {
fn default() -> Self {
Self::new()
}
}
/// Wrapper around a `sqlx::SqlitePool`.
///
/// With the `observability` feature enabled it also carries the optional
/// observer used to report operations performed through the provider.
#[derive(Debug, Clone)]
pub struct SqlitePoolProvider {
    // The wrapped connection pool.
    pool: sqlx::SqlitePool,
    // Optional observer; `None` when no observer was attached.
    #[cfg(feature = "observability")]
    observer: Option<nidus_observability::ObservabilityAdapterObserver>,
}
impl SqlitePoolProvider {
    /// Returns a new [`SqlitePoolBuilder`] with its default settings.
    pub fn builder() -> SqlitePoolBuilder {
        SqlitePoolBuilder::new()
    }

    /// Wraps an already-opened pool; no observer is attached.
    pub fn from_pool(pool: sqlx::SqlitePool) -> Self {
        Self {
            pool,
            #[cfg(feature = "observability")]
            observer: None,
        }
    }

    /// Borrows the wrapped pool.
    pub fn pool(&self) -> &sqlx::SqlitePool {
        &self.pool
    }

    /// Consumes the provider and returns the wrapped pool.
    pub fn into_pool(self) -> sqlx::SqlitePool {
        let Self { pool, .. } = self;
        pool
    }

    /// Runs `SELECT 1` against the pool and maps the outcome to a health status.
    ///
    /// A successful query yields `HealthStatus::up()`; a failure yields
    /// `HealthStatus::down` carrying the error's display text. With the
    /// `observability` feature the probe is reported as a `health` operation.
    #[cfg(feature = "health")]
    pub async fn health_status(&self) -> nidus_http::health::HealthStatus {
        use nidus_http::health::HealthStatus;

        #[cfg(feature = "observability")]
        let started_at = Instant::now();
        let probe = sqlx::query("SELECT 1").execute(&self.pool).await;

        #[cfg(feature = "observability")]
        record_adapter_operation(
            &self.observer,
            "health",
            nidus_observability::OperationStatus::from(probe.is_ok()),
            started_at,
        );

        if let Err(error) = probe {
            return HealthStatus::down(error.to_string());
        }
        HealthStatus::up()
    }

    /// Adds a ready check called `name` to `registry` and returns the registry.
    ///
    /// The registered closure clones the shared provider and awaits
    /// [`SqlitePoolProvider::health_status`] on it.
    #[cfg(feature = "health")]
    pub fn register_ready_check(
        self: std::sync::Arc<Self>,
        registry: nidus_http::health::HealthRegistry,
        name: impl Into<String>,
    ) -> nidus_http::health::HealthRegistry {
        let shared = self;
        registry.ready_check(name, move || {
            // Each invocation gets its own handle so the future owns its data.
            let provider = std::sync::Arc::clone(&shared);
            async move { provider.health_status().await }
        })
    }
}
/// Reports one adapter operation to `observer`, if an observer is present.
///
/// The record is tagged with the adapter name `nidus-sqlx` and carries the
/// time elapsed since `started_at`. Does nothing when `observer` is `None`.
#[cfg(feature = "observability")]
fn record_adapter_operation(
    observer: &Option<nidus_observability::ObservabilityAdapterObserver>,
    operation: &'static str,
    status: nidus_observability::OperationStatus,
    started_at: Instant,
) {
    if let Some(sink) = observer.as_ref() {
        let elapsed = started_at.elapsed();
        sink.record("nidus-sqlx", operation, status, elapsed);
    }
}
}
#[cfg(feature = "sqlite")]
pub use sqlite::{SqlitePoolBuilder, SqlitePoolConfig, SqlitePoolProvider};
#[cfg(feature = "postgres")]
mod postgres {
#[cfg(feature = "observability")]
use std::time::Instant;
use super::Result;
use nidus_core::Container;
/// Settings used to open a PostgreSQL connection pool.
///
/// Holds the database URL plus optional upper and lower bounds on the number
/// of pooled connections; `None` means the bound is not passed to the pool
/// options.
///
/// `Debug` is implemented by hand rather than derived: PostgreSQL URLs
/// commonly embed a password (in the user-info part or as a query parameter),
/// and a derived `Debug` would print it into logs and panic messages.
#[derive(Clone, Eq, PartialEq)]
pub struct PostgresPoolConfig {
    // URL handed to the pool when connecting; may contain credentials.
    database_url: String,
    // Optional upper bound on pooled connections.
    max_connections: Option<u32>,
    // Optional lower bound on pooled connections.
    min_connections: Option<u32>,
}

impl std::fmt::Debug for PostgresPoolConfig {
    /// Formats the configuration with everything after the URL scheme replaced
    /// by `<redacted>`, so credentials never reach the output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Keep only the scheme: credentials can appear in several places in
        // the URL, so redacting the remainder wholesale is the safe option.
        let redacted_url = match self.database_url.split_once("://") {
            Some((scheme, _)) => format!("{}://<redacted>", scheme),
            None => String::from("<redacted>"),
        };
        f.debug_struct("PostgresPoolConfig")
            .field("database_url", &redacted_url)
            .field("max_connections", &self.max_connections)
            .field("min_connections", &self.min_connections)
            .finish()
    }
}
impl PostgresPoolConfig {
    /// Creates a configuration for `database_url` with no connection bounds set.
    pub fn new(database_url: impl Into<String>) -> Self {
        let database_url: String = database_url.into();
        Self {
            database_url,
            max_connections: None,
            min_connections: None,
        }
    }

    /// Returns the configuration with the upper bound set to `max_connections`.
    pub fn with_max_connections(self, max_connections: u32) -> Self {
        Self {
            max_connections: Some(max_connections),
            ..self
        }
    }

    /// Returns the configuration with the lower bound set to `min_connections`.
    pub fn with_min_connections(self, min_connections: u32) -> Self {
        Self {
            min_connections: Some(min_connections),
            ..self
        }
    }

    /// Returns the configured database URL.
    pub fn database_url(&self) -> &str {
        self.database_url.as_str()
    }

    /// Returns the configured upper bound, if one was set.
    pub fn max_connections(&self) -> Option<u32> {
        self.max_connections
    }

    /// Returns the configured lower bound, if one was set.
    pub fn min_connections(&self) -> Option<u32> {
        self.min_connections
    }

    /// Builds a configuration from the value stored at `path` in `config`.
    ///
    /// The value is deserialized into a `url` string and optional
    /// `max_connections` and `min_connections` integers.
    ///
    /// # Errors
    ///
    /// Propagates the error returned by `Config::get_required_path_typed`.
    #[cfg(feature = "nidus-config")]
    pub fn from_config_path<I, S>(config: &nidus_config::Config, path: I) -> Result<Self>
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        #[derive(serde::Deserialize)]
        struct RawConfig {
            url: String,
            max_connections: Option<u32>,
            min_connections: Option<u32>,
        }

        let raw: RawConfig = config.get_required_path_typed(path)?;
        // Absent bounds stay `None`, present ones are stored as-is.
        Ok(Self {
            database_url: raw.url,
            max_connections: raw.max_connections,
            min_connections: raw.min_connections,
        })
    }
}
/// Builder that accumulates PostgreSQL pool settings before connecting.
///
/// With the `observability` feature enabled it can also carry an observer
/// that is handed on to the provider it creates.
#[derive(Debug, Clone)]
pub struct PostgresPoolBuilder {
    // Settings applied when the pool is opened.
    config: PostgresPoolConfig,
    // Optional observer; `None` until one is supplied.
    #[cfg(feature = "observability")]
    observer: Option<nidus_observability::ObservabilityAdapterObserver>,
}
impl PostgresPoolBuilder {
    /// Creates a builder for `database_url` with no connection bounds set.
    pub fn new(database_url: impl Into<String>) -> Self {
        let config = PostgresPoolConfig::new(database_url);
        Self {
            config,
            #[cfg(feature = "observability")]
            observer: None,
        }
    }

    /// Replaces the builder's entire configuration with `config`.
    pub fn config(mut self, config: PostgresPoolConfig) -> Self {
        self.config = config;
        self
    }

    /// Sets the database URL, leaving the other settings untouched.
    pub fn database_url(mut self, database_url: impl Into<String>) -> Self {
        let url: String = database_url.into();
        self.config.database_url = url;
        self
    }

    /// Sets the maximum number of pooled connections.
    pub fn max_connections(mut self, max_connections: u32) -> Self {
        self.config = self.config.with_max_connections(max_connections);
        self
    }

    /// Sets the minimum number of pooled connections.
    pub fn min_connections(mut self, min_connections: u32) -> Self {
        self.config = self.config.with_min_connections(min_connections);
        self
    }

    /// Attaches an observer that receives the `connect` operation and is
    /// passed on to the resulting provider.
    #[cfg(feature = "observability")]
    pub fn observability(
        self,
        observer: nidus_observability::ObservabilityAdapterObserver,
    ) -> Self {
        Self {
            observer: Some(observer),
            ..self
        }
    }

    /// Opens the pool described by this builder.
    ///
    /// With the `observability` feature, the attempt is reported to the
    /// observer as a `connect` operation whether it succeeds or fails.
    ///
    /// # Errors
    ///
    /// Returns the `sqlx` error if the pool cannot be opened.
    pub async fn connect(self) -> Result<PostgresPoolProvider> {
        #[cfg(feature = "observability")]
        let observer = self.observer;

        // Only pass bounds to sqlx when they were configured.
        let options = sqlx::postgres::PgPoolOptions::new();
        let options = match self.config.max_connections {
            Some(upper) => options.max_connections(upper),
            None => options,
        };
        let options = match self.config.min_connections {
            Some(lower) => options.min_connections(lower),
            None => options,
        };

        #[cfg(feature = "observability")]
        let started_at = Instant::now();
        let attempt = options.connect(&self.config.database_url).await;

        // Report before `?` so failed attempts are recorded as well.
        #[cfg(feature = "observability")]
        record_adapter_operation(
            &observer,
            "connect",
            nidus_observability::OperationStatus::from(attempt.is_ok()),
            started_at,
        );

        let pool = attempt?;
        let provider = PostgresPoolProvider {
            pool,
            #[cfg(feature = "observability")]
            observer,
        };
        Ok(provider)
    }

    /// Connects and registers the resulting provider in `container` as a singleton.
    ///
    /// # Errors
    ///
    /// Returns an error if connecting fails or if the container rejects the
    /// registration.
    pub async fn register(self, container: &mut Container) -> Result<()> {
        let provider = self.connect().await?;
        container.register_singleton(provider)?;
        Ok(())
    }
}
/// Wrapper around a `sqlx::PgPool`.
///
/// With the `observability` feature enabled it also carries the optional
/// observer used to report operations performed through the provider.
#[derive(Debug, Clone)]
pub struct PostgresPoolProvider {
    // The wrapped connection pool.
    pool: sqlx::PgPool,
    // Optional observer; `None` when no observer was attached.
    #[cfg(feature = "observability")]
    observer: Option<nidus_observability::ObservabilityAdapterObserver>,
}
impl PostgresPoolProvider {
    /// Returns a new [`PostgresPoolBuilder`] targeting `database_url`.
    pub fn builder(database_url: impl Into<String>) -> PostgresPoolBuilder {
        PostgresPoolBuilder::new(database_url)
    }

    /// Wraps an already-opened pool; no observer is attached.
    pub fn from_pool(pool: sqlx::PgPool) -> Self {
        Self {
            pool,
            #[cfg(feature = "observability")]
            observer: None,
        }
    }

    /// Borrows the wrapped pool.
    pub fn pool(&self) -> &sqlx::PgPool {
        &self.pool
    }

    /// Consumes the provider and returns the wrapped pool.
    pub fn into_pool(self) -> sqlx::PgPool {
        let Self { pool, .. } = self;
        pool
    }

    /// Runs `SELECT 1` against the pool and maps the outcome to a health status.
    ///
    /// A successful query yields `HealthStatus::up()`; a failure yields
    /// `HealthStatus::down` carrying the error's display text. With the
    /// `observability` feature the probe is reported as a `health` operation.
    #[cfg(feature = "health")]
    pub async fn health_status(&self) -> nidus_http::health::HealthStatus {
        use nidus_http::health::HealthStatus;

        #[cfg(feature = "observability")]
        let started_at = Instant::now();
        let probe = sqlx::query("SELECT 1").execute(&self.pool).await;

        #[cfg(feature = "observability")]
        record_adapter_operation(
            &self.observer,
            "health",
            nidus_observability::OperationStatus::from(probe.is_ok()),
            started_at,
        );

        if let Err(error) = probe {
            return HealthStatus::down(error.to_string());
        }
        HealthStatus::up()
    }

    /// Adds a ready check called `name` to `registry` and returns the registry.
    ///
    /// The registered closure clones the shared provider and awaits
    /// [`PostgresPoolProvider::health_status`] on it.
    #[cfg(feature = "health")]
    pub fn register_ready_check(
        self: std::sync::Arc<Self>,
        registry: nidus_http::health::HealthRegistry,
        name: impl Into<String>,
    ) -> nidus_http::health::HealthRegistry {
        let shared = self;
        registry.ready_check(name, move || {
            // Each invocation gets its own handle so the future owns its data.
            let provider = std::sync::Arc::clone(&shared);
            async move { provider.health_status().await }
        })
    }
}
/// Reports one adapter operation to `observer`, if an observer is present.
///
/// The record is tagged with the adapter name `nidus-sqlx` and carries the
/// time elapsed since `started_at`. Does nothing when `observer` is `None`.
#[cfg(feature = "observability")]
fn record_adapter_operation(
    observer: &Option<nidus_observability::ObservabilityAdapterObserver>,
    operation: &'static str,
    status: nidus_observability::OperationStatus,
    started_at: Instant,
) {
    if let Some(sink) = observer.as_ref() {
        let elapsed = started_at.elapsed();
        sink.record("nidus-sqlx", operation, status, elapsed);
    }
}
}
#[cfg(feature = "postgres")]
pub use postgres::{PostgresPoolBuilder, PostgresPoolConfig, PostgresPoolProvider};