use std::sync::Arc;
use async_trait::async_trait;
use crate::error::{Error, Result};
/// A pluggable storage backend that can be attached to a [`State`].
///
/// The `Send + Sync + 'static` bounds let a provider live inside
/// `State` as an `Arc<dyn StorageProvider>` and be shared across tasks.
#[async_trait]
pub trait StorageProvider: Send + Sync + 'static {
    /// Health-checks the backend.
    ///
    /// [`State::with_storage`] awaits this before installing the provider and
    /// propagates any error it returns.
    async fn probe(&self) -> Result<()>;

    /// Short identifier of the backing system.
    ///
    /// Implementors that do not override this report `"custom"`.
    fn system(&self) -> &'static str {
        "custom"
    }
}
/// Shared application state holding optional backend handles.
///
/// Every backend is optional. Accessors such as [`State::pg`], [`State::redis`]
/// and [`State::storage`] return `Error::Config` when the backend was not wired.
/// The Redis client and storage provider sit behind `Arc`, so clones of a
/// `State` share the same client and provider.
#[derive(Default, Clone)]
pub struct State {
    /// Postgres pool; populated by [`State::from_env`] when `DATABASE_URL` is set.
    pg: Option<sqlx::PgPool>,
    /// Redis client; populated by [`State::from_env`] when `REDIS_URL` is set.
    redis: Option<Arc<redis::Client>>,
    /// Storage backend; installed via [`State::with_storage`] or [`State::set_storage`].
    storage: Option<Arc<dyn StorageProvider>>,
}
impl State {
pub async fn with_storage<S: StorageProvider>(mut self, storage: S) -> Result<Self> {
storage.probe().await?;
self.storage = Some(Arc::new(storage));
Ok(self)
}
pub fn set_storage(&mut self, storage: Arc<dyn StorageProvider>) {
self.storage = Some(storage);
}
pub fn has_storage(&self) -> bool {
self.storage.is_some()
}
pub fn storage(&self) -> Result<&Arc<dyn StorageProvider>> {
self.storage.as_ref().ok_or_else(|| {
Error::Config(
"storage requested but no provider was wired at startup \
(scaffold with --with-storage to enable)"
.into(),
)
})
}
pub async fn from_env() -> Result<Self> {
let pg = match std::env::var("DATABASE_URL") {
Ok(url) => {
tracing::info!(target: "tonin::state", "connecting to postgres");
let pool = sqlx::postgres::PgPoolOptions::new()
.max_connections(default_pg_max_conns())
.connect(&url)
.await
.map_err(|e| Error::Config(format!("postgres connect failed: {e}")))?;
Some(pool)
}
Err(_) => None,
};
let redis = match std::env::var("REDIS_URL") {
Ok(url) => {
tracing::info!(target: "tonin::state", "connecting to redis");
let client = redis::Client::open(url)
.map_err(|e| Error::Config(format!("redis client init: {e}")))?;
let mut conn = client
.get_multiplexed_async_connection()
.await
.map_err(|e| Error::Config(format!("redis connect failed: {e}")))?;
let _: String = redis::cmd("PING")
.query_async(&mut conn)
.await
.map_err(|e| Error::Config(format!("redis PING failed: {e}")))?;
Some(Arc::new(client))
}
Err(_) => None,
};
Ok(Self {
pg,
redis,
storage: None,
})
}
pub fn pg(&self) -> Result<&sqlx::PgPool> {
self.pg.as_ref().ok_or_else(|| {
Error::Config("postgres requested but DATABASE_URL was not set at startup".into())
})
}
pub fn has_pg(&self) -> bool {
self.pg.is_some()
}
pub fn redis(&self) -> Result<&redis::Client> {
self.redis.as_deref().ok_or_else(|| {
Error::Config("redis requested but REDIS_URL was not set at startup".into())
})
}
pub fn has_redis(&self) -> bool {
self.redis.is_some()
}
}
/// Maximum size of the Postgres pool, read from `TONIN_PG_MAX_CONNECTIONS`.
///
/// Surrounding whitespace is ignored, so values such as `"20\n"` from env files
/// still parse. The fallback of 10 is used when the variable is unset, not
/// valid UTF-8, not a `u32`, or `0`. Zero is rejected because a pool capped at
/// zero connections can never serve a query.
fn default_pg_max_conns() -> u32 {
    const FALLBACK: u32 = 10;
    std::env::var("TONIN_PG_MAX_CONNECTIONS")
        .ok()
        .and_then(|raw| raw.trim().parse::<u32>().ok())
        .filter(|&n| n > 0)
        .unwrap_or(FALLBACK)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    /// Test double for `StorageProvider` with a fixed probe outcome and a probe counter.
    struct MockStorage {
        /// Shared with the test so the count stays readable after the mock is
        /// moved into `State`.
        probes: Arc<AtomicUsize>,
        probe_fails: bool,
    }

    impl MockStorage {
        /// Builds a mock plus a handle to its probe counter.
        fn new(probe_fails: bool) -> (Self, Arc<AtomicUsize>) {
            let probes = Arc::new(AtomicUsize::new(0));
            let mock = Self {
                probes: Arc::clone(&probes),
                probe_fails,
            };
            (mock, probes)
        }
    }

    #[async_trait]
    impl StorageProvider for MockStorage {
        async fn probe(&self) -> Result<()> {
            self.probes.fetch_add(1, Ordering::SeqCst);
            if self.probe_fails {
                Err(Error::Config("mock probe failure".into()))
            } else {
                Ok(())
            }
        }

        fn system(&self) -> &'static str {
            "memory"
        }
    }

    #[tokio::test]
    async fn empty_state_when_no_env_vars() {
        // `from_env` reads the real process environment; skip instead of
        // connecting to whatever this machine has configured. `var_os` also
        // catches values that are set but not valid UTF-8.
        if std::env::var_os("DATABASE_URL").is_some() || std::env::var_os("REDIS_URL").is_some() {
            return;
        }
        let state = State::from_env().await.unwrap();
        assert!(!state.has_pg());
        assert!(!state.has_redis());
        assert!(!state.has_storage());
        assert!(state.pg().is_err());
        assert!(state.redis().is_err());
        assert!(state.storage().is_err());
    }

    #[test]
    fn default_state_has_no_backends() {
        let state = State::default();
        assert!(!state.has_pg());
        assert!(!state.has_redis());
        assert!(!state.has_storage());
        assert!(state.pg().is_err());
        assert!(state.redis().is_err());
        assert!(state.storage().is_err());
    }

    #[tokio::test]
    async fn with_storage_runs_probe() {
        let (storage, probes) = MockStorage::new(false);
        let state = State::default().with_storage(storage).await.unwrap();
        assert_eq!(probes.load(Ordering::SeqCst), 1, "probe should run exactly once");
        assert!(state.has_storage());
        assert_eq!(state.storage().unwrap().system(), "memory");
    }

    #[tokio::test]
    async fn with_storage_propagates_probe_failure() {
        let (storage, probes) = MockStorage::new(true);
        match State::default().with_storage(storage).await {
            Ok(_) => panic!("expected probe failure to propagate"),
            Err(Error::Config(_)) => {}
            Err(other) => panic!("expected Config, got {other:?}"),
        }
        assert_eq!(probes.load(Ordering::SeqCst), 1, "probe should run exactly once");
    }

    #[test]
    fn set_storage_installs_without_probing() {
        // Use a failing mock: if `set_storage` probed, it would still not error,
        // but the counter would move.
        let (storage, probes) = MockStorage::new(true);
        let mut state = State::default();
        state.set_storage(Arc::new(storage));
        assert!(state.has_storage());
        assert_eq!(state.storage().unwrap().system(), "memory");
        assert_eq!(probes.load(Ordering::SeqCst), 0);
    }
}