use std::time::{Duration, Instant};
use chrono::{DateTime, Utc};
use dashmap::DashMap;
use once_cell::sync::Lazy;
use sqlx::Row as _;
use crate::error::Result;
use crate::orm::Db;
/// Number of seconds a cached flag value is trusted before it is re-read.
const FLAG_CACHE_TTL_SECS: u64 = 60;
/// Lifetime of an entry in the in-process flag cache.
const FLAG_CACHE_TTL: Duration = Duration::from_secs(FLAG_CACHE_TTL_SECS);
/// A single cached flag lookup result.
#[derive(Clone, Copy, Debug)]
struct CacheEntry {
    /// The flag state that was observed when this entry was written.
    enabled: bool,
    /// Point in time after which the entry is considered stale.
    expires: Instant,
}
/// Process-wide cache of flag states keyed by flag key, lazily created on first use.
static FLAG_CACHE: Lazy<DashMap<String, CacheEntry>> = Lazy::new(DashMap::default);
/// DDL for the table that stores feature flags.
///
/// Uses `CREATE TABLE IF NOT EXISTS`, so running it repeatedly is harmless.
/// Columns: `key` (primary key), `enabled` (defaults to `FALSE`),
/// `description` (defaults to the empty string), and `created_at` /
/// `updated_at` timestamps that both default to `NOW()`.
// NOTE(review): `TIMESTAMPTZ` / `NOW()` suggest PostgreSQL is the target backend — confirm.
pub(crate) const CREATE_TABLE_SQL: &str = "CREATE TABLE IF NOT EXISTS rustio_feature_flags (
key TEXT PRIMARY KEY,
enabled BOOLEAN NOT NULL DEFAULT FALSE,
description TEXT NOT NULL DEFAULT '',
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
)";
/// Creates the `rustio_feature_flags` table if it does not exist yet.
///
/// # Errors
/// Returns an error if executing [`CREATE_TABLE_SQL`] fails.
pub async fn ensure_table(db: &Db) -> Result<()> {
    let pool = db.pool();
    let _ = sqlx::query(CREATE_TABLE_SQL).execute(pool).await?;
    Ok(())
}
#[derive(Debug, Clone)]
pub struct FeatureFlag {
pub key: String,
pub enabled: bool,
pub description: String,
pub created_at: DateTime<Utc>,
pub updated_at: DateTime<Utc>,
}
pub async fn feature_enabled(db: &Db, key: &str) -> bool {
if let Some(entry) = FLAG_CACHE.get(key) {
if entry.expires > Instant::now() {
return entry.enabled;
}
}
let enabled: Option<bool> =
sqlx::query_scalar("SELECT enabled FROM rustio_feature_flags WHERE key = $1")
.bind(key)
.fetch_optional(db.pool())
.await
.ok()
.flatten();
let enabled = enabled.unwrap_or(false);
FLAG_CACHE.insert(
key.to_string(),
CacheEntry {
enabled,
expires: Instant::now() + FLAG_CACHE_TTL,
},
);
enabled
}
/// Empties the in-process flag cache.
///
/// After this call, the next [`feature_enabled`] lookup for any key goes back
/// to the database.
pub fn invalidate_cache() {
    FLAG_CACHE.clear();
}
pub(crate) async fn list_flags(db: &Db) -> Result<Vec<FeatureFlag>> {
ensure_table(db).await?;
let rows = sqlx::query(
"SELECT key, enabled, description, created_at, updated_at \
FROM rustio_feature_flags ORDER BY created_at DESC",
)
.fetch_all(db.pool())
.await?;
let out = rows
.iter()
.map(|r| FeatureFlag {
key: r.try_get("key").unwrap_or_default(),
enabled: r.try_get("enabled").unwrap_or(false),
description: r.try_get("description").unwrap_or_default(),
created_at: r.try_get("created_at").unwrap_or_else(|_| Utc::now()),
updated_at: r.try_get("updated_at").unwrap_or_else(|_| Utc::now()),
})
.collect();
Ok(out)
}
/// Inserts a new, disabled flag named `key` with the given `description`.
///
/// Creating a key that already exists is a no-op (`ON CONFLICT (key) DO
/// NOTHING`), so the existing row is left untouched. The in-process flag
/// cache is cleared afterwards either way.
///
/// # Errors
/// Propagates any error from creating the table or running the insert.
pub(crate) async fn create_flag(db: &Db, key: &str, description: &str) -> Result<()> {
    const INSERT_SQL: &str = "INSERT INTO rustio_feature_flags (key, enabled, description) \
                              VALUES ($1, FALSE, $2) ON CONFLICT (key) DO NOTHING";

    ensure_table(db).await?;

    let pool = db.pool();
    sqlx::query(INSERT_SQL)
        .bind(key)
        .bind(description)
        .execute(pool)
        .await?;

    invalidate_cache();
    Ok(())
}
/// Sets the `enabled` state of flag `key` and bumps its `updated_at`.
///
/// Ensures the table exists first and clears the in-process flag cache
/// afterwards. The affected-row count of the `UPDATE` is not checked, so a
/// `key` that has no row is silently left absent rather than reported.
///
/// # Errors
/// Propagates any error from creating the table or running the update.
pub(crate) async fn set_flag(db: &Db, key: &str, enabled: bool) -> Result<()> {
    const UPDATE_SQL: &str = "UPDATE rustio_feature_flags \
                              SET enabled = $1, updated_at = NOW() \
                              WHERE key = $2";

    ensure_table(db).await?;

    let statement = sqlx::query(UPDATE_SQL).bind(enabled).bind(key);
    statement.execute(db.pool()).await?;

    invalidate_cache();
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn flag_cache_ttl_is_60_seconds() {
        let expected = Duration::from_secs(60);
        assert_eq!(FLAG_CACHE_TTL, expected);
    }

    #[test]
    fn invalidate_cache_clears_every_entry() {
        // Seed one enabled and one disabled entry, both far from expiry.
        let seeded = [("key_a", true), ("key_b", false)];
        for (key, enabled) in seeded {
            let entry = CacheEntry {
                enabled,
                expires: Instant::now() + Duration::from_secs(60),
            };
            FLAG_CACHE.insert(key.to_owned(), entry);
        }

        invalidate_cache();

        for (key, _) in seeded {
            assert!(FLAG_CACHE.get(key).is_none());
        }
    }
}