use std::time::Duration;
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use sa_token_adapter::storage::{SaStorage, StorageError, StorageResult};
use sqlx::{Pool, Postgres};
#[derive(Clone)]
pub struct DatabaseStorage {
pool: Pool<Postgres>,
}
impl DatabaseStorage {
pub async fn new(database_url: &str) -> StorageResult<Self> {
let pool = Pool::<Postgres>::connect(database_url)
.await
.map_err(|e| StorageError::ConnectionError(e.to_string()))?;
let storage = Self { pool };
storage.migrate().await?;
Ok(storage)
}
pub fn from_pool(pool: Pool<Postgres>) -> Self {
Self { pool }
}
pub async fn migrate(&self) -> StorageResult<()> {
let ddl = include_str!("../migrations/001_sa_token_storage.sql");
for statement in ddl.split(';').map(str::trim).filter(|s| !s.is_empty()) {
sqlx::query(statement)
.execute(&self.pool)
.await
.map_err(|e| StorageError::OperationFailed(e.to_string()))?;
}
Ok(())
}
async fn delete_expired(&self, key: &str) -> StorageResult<()> {
sqlx::query("DELETE FROM sa_token_storage WHERE key = $1")
.bind(key)
.execute(&self.pool)
.await
.map_err(|e| StorageError::OperationFailed(e.to_string()))?;
Ok(())
}
fn is_expired(expire_at: Option<DateTime<Utc>>) -> bool {
expire_at.is_some_and(|t| Utc::now() > t)
}
}
pub fn pattern_to_like(pattern: &str) -> String {
let mut out = String::with_capacity(pattern.len());
for ch in pattern.chars() {
match ch {
'*' => out.push('%'),
'%' | '_' | '\\' => {
out.push('\\');
out.push(ch);
}
other => out.push(other),
}
}
out
}
#[async_trait]
impl SaStorage for DatabaseStorage {
async fn get(&self, key: &str) -> StorageResult<Option<String>> {
let row: Option<(String, Option<DateTime<Utc>>)> = sqlx::query_as(
"SELECT value, expire_at FROM sa_token_storage WHERE key = $1",
)
.bind(key)
.fetch_optional(&self.pool)
.await
.map_err(|e| StorageError::OperationFailed(e.to_string()))?;
match row {
Some((_value, expire_at)) if Self::is_expired(expire_at) => {
self.delete_expired(key).await?;
Ok(None)
}
Some((value, _)) => Ok(Some(value)),
None => Ok(None),
}
}
async fn set(&self, key: &str, value: &str, ttl: Option<Duration>) -> StorageResult<()> {
let expire_at: Option<DateTime<Utc>> = ttl.map(|d| Utc::now() + chrono::Duration::from_std(d).unwrap());
sqlx::query(
r#"
INSERT INTO sa_token_storage (key, value, expire_at, updated_at)
VALUES ($1, $2, $3, NOW())
ON CONFLICT (key) DO UPDATE
SET value = EXCLUDED.value,
expire_at = EXCLUDED.expire_at,
updated_at = NOW()
"#,
)
.bind(key)
.bind(value)
.bind(expire_at)
.execute(&self.pool)
.await
.map_err(|e| StorageError::OperationFailed(e.to_string()))?;
Ok(())
}
async fn delete(&self, key: &str) -> StorageResult<()> {
sqlx::query("DELETE FROM sa_token_storage WHERE key = $1")
.bind(key)
.execute(&self.pool)
.await
.map_err(|e| StorageError::OperationFailed(e.to_string()))?;
Ok(())
}
async fn exists(&self, key: &str) -> StorageResult<bool> {
Ok(self.get(key).await?.is_some())
}
async fn expire(&self, key: &str, ttl: Duration) -> StorageResult<()> {
let expire_at = Utc::now() + chrono::Duration::from_std(ttl).unwrap();
let updated = sqlx::query(
"UPDATE sa_token_storage SET expire_at = $1, updated_at = NOW() WHERE key = $2",
)
.bind(expire_at)
.bind(key)
.execute(&self.pool)
.await
.map_err(|e| StorageError::OperationFailed(e.to_string()))?
.rows_affected();
if updated == 0 {
return Err(StorageError::KeyNotFound(key.to_string()));
}
Ok(())
}
async fn ttl(&self, key: &str) -> StorageResult<Option<Duration>> {
let row: Option<Option<DateTime<Utc>>> = sqlx::query_scalar(
"SELECT expire_at FROM sa_token_storage WHERE key = $1",
)
.bind(key)
.fetch_optional(&self.pool)
.await
.map_err(|e| StorageError::OperationFailed(e.to_string()))?;
match row {
None => Ok(None),
Some(None) => Ok(None),
Some(Some(expire_at)) if Self::is_expired(Some(expire_at)) => {
self.delete_expired(key).await?;
Ok(None)
}
Some(Some(expire_at)) => {
let remaining = expire_at.signed_duration_since(Utc::now());
if remaining.num_milliseconds() <= 0 {
Ok(Some(Duration::ZERO))
} else {
Ok(Some(
remaining
.to_std()
.unwrap_or(Duration::ZERO),
))
}
}
}
}
async fn clear(&self) -> StorageResult<()> {
sqlx::query("TRUNCATE TABLE sa_token_storage")
.execute(&self.pool)
.await
.map_err(|e| StorageError::OperationFailed(e.to_string()))?;
Ok(())
}
async fn keys(&self, pattern: &str) -> StorageResult<Vec<String>> {
let like_pattern = pattern_to_like(pattern);
let rows: Vec<(String, Option<DateTime<Utc>>)> = sqlx::query_as(
"SELECT key, expire_at FROM sa_token_storage WHERE key LIKE $1 ESCAPE '\\'",
)
.bind(like_pattern)
.fetch_all(&self.pool)
.await
.map_err(|e| StorageError::OperationFailed(e.to_string()))?;
let mut keys = Vec::new();
for (key, expire_at) in rows {
if Self::is_expired(expire_at) {
self.delete_expired(&key).await?;
} else {
keys.push(key);
}
}
Ok(keys)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn like_pattern_escapes_wildcards() {
assert_eq!(pattern_to_like("sa:token:*"), "sa:token:%");
assert_eq!(pattern_to_like("a%b_c"), "a\\%b\\_c");
}
}
#[cfg(all(test, feature = "postgres"))]
mod postgres_tests {
use super::*;
fn database_url() -> Option<String> {
std::env::var("DATABASE_URL").ok()
}
#[tokio::test]
#[ignore = "requires PostgreSQL (set DATABASE_URL)"]
async fn database_storage_roundtrip() {
let Some(url) = database_url() else {
return;
};
let storage = DatabaseStorage::new(&url).await.expect("connect");
storage.set("sa:test:1", "v1", Some(Duration::from_secs(60))).await.unwrap();
assert_eq!(storage.get("sa:test:1").await.unwrap(), Some("v1".into()));
assert!(storage.exists("sa:test:1").await.unwrap());
let ttl = storage.ttl("sa:test:1").await.unwrap();
assert!(ttl.is_some());
let keys = storage.keys("sa:test:*").await.unwrap();
assert!(keys.contains(&"sa:test:1".to_string()));
storage.delete("sa:test:1").await.unwrap();
assert!(!storage.exists("sa:test:1").await.unwrap());
}
}