use async_trait::async_trait;
use dashmap::DashMap;
use std::collections::hash_map::DefaultHasher;
use std::hash::Hasher;
use std::sync::OnceLock;
use std::time::{Duration, Instant};
/// Maps a `(flag, identifier)` pair onto a bucket in `0..100`.
///
/// The flag name and then the identifier are fed into a freshly created
/// [`DefaultHasher`]; the 64-bit digest is reduced modulo 100. Percentage
/// rollouts and variant splits compare their thresholds against this bucket.
///
/// NOTE(review): the standard library does not promise that `DefaultHasher`
/// keeps the same algorithm across Rust releases, so bucket assignments may
/// move after a toolchain upgrade — confirm whether that is acceptable.
///
/// NOTE(review): nothing separates the two inputs, so pairs such as
/// `("ab", "c")` and `("a", "bc")` feed the hasher the same byte sequence and
/// may therefore land in the same bucket.
pub fn calculate_hash_bucket(flag: &str, identifier: &str) -> u32 {
    let mut state = DefaultHasher::new();
    for chunk in [flag.as_bytes(), identifier.as_bytes()] {
        state.write(chunk);
    }
    // The remainder is always below 100, so narrowing to `u32` is lossless.
    (state.finish() % 100) as u32
}
/// Parses a rollout percentage such as `"30%"`, `" 100% "` or `"0"`.
///
/// Surrounding whitespace and trailing `%` signs are ignored. Whitespace
/// between the number and the percent sign (`"30 %"`) is tolerated as well,
/// which the previous implementation rejected.
///
/// Returns `None` when the remaining text is not a valid `u32`. The value is
/// not range-checked: numbers above 100 are returned unchanged.
pub fn parse_rollout(s: &str) -> Option<u32> {
    s.trim()
        .trim_end_matches('%')
        // Drop any gap that was sitting between the digits and the `%`.
        .trim_end()
        .parse::<u32>()
        .ok()
}
/// Parses a variant specification of the form `"name:weight,name:weight"`.
///
/// Each comma-separated entry is split on `:`; the first field is the variant
/// name and the second its `u32` weight (both trimmed). Any further
/// `:`-separated fields in an entry are ignored.
///
/// Entries are skipped, not reported, when they have no `:`, when the weight
/// is not a valid `u32`, or when the name is empty — an unnamed variant could
/// never be matched by a caller. The returned pairs keep their input order.
pub fn parse_variants(s: &str) -> Vec<(String, u32)> {
    s.split(',')
        .filter_map(|entry| {
            let mut fields = entry.split(':');
            let name = fields.next()?.trim();
            let weight = fields.next()?.trim().parse::<u32>().ok()?;
            if name.is_empty() {
                return None;
            }
            Some((name.to_string(), weight))
        })
        .collect()
}
/// Picks the variant whose cumulative weight range contains `bucket`.
///
/// Weights are summed in slice order; the first variant for which `bucket` is
/// strictly below the running total is returned. Returns `None` when the
/// slice is empty or `bucket` is not below the total of all weights.
///
/// The running total saturates at `u32::MAX` instead of overflowing, so
/// oversized weights can neither panic (debug builds) nor wrap around
/// (release builds).
pub fn resolve_variant(variants: &[(String, u32)], bucket: u32) -> Option<String> {
    let mut upper_bound: u32 = 0;
    for (name, weight) in variants {
        upper_bound = upper_bound.saturating_add(*weight);
        if bucket < upper_bound {
            return Some(name.clone());
        }
    }
    None
}
/// A source of feature-flag decisions.
///
/// Every method returns `None` when the driver has nothing to say about the
/// flag; `FeatureManager` then asks the next driver in its list.
#[async_trait]
pub trait FeatureDriver: Send + Sync {
/// Reports whether `flag` is on when no identifier is available.
async fn enabled(&self, flag: &str) -> Option<bool>;
/// Reports whether `flag` is on for the given `identifier`.
async fn enabled_for(&self, flag: &str, identifier: &str) -> Option<bool>;
/// Returns the variant label of `flag` for the given `identifier`.
async fn variant(&self, flag: &str, identifier: &str) -> Option<String>;
}
/// One in-memory rule stored by `MemoryFeatureDriver`.
struct MemoryFlagRule {
// Master switch: when false the flag is reported as off regardless of the other fields.
enabled: bool,
// Optional rollout threshold; compared against the 0..100 hash bucket.
rollout_percentage: Option<u32>,
// Optional `(name, weight)` pairs handed to `resolve_variant`.
variants: Option<Vec<(String, u32)>>,
}
/// Feature driver whose rules live in a `DashMap` and are set through the
/// `override_*` methods.
#[non_exhaustive]
pub struct MemoryFeatureDriver {
// Rules keyed by flag name.
rules: DashMap<String, MemoryFlagRule>,
}
impl MemoryFeatureDriver {
    /// Creates a driver with no rules.
    pub fn new() -> Self {
        let rules = DashMap::new();
        Self { rules }
    }

    /// Replaces any rule stored for `flag` with a plain on/off switch.
    pub fn override_enabled(&self, flag: &str, enabled: bool) {
        self.put_rule(flag, enabled, None, None);
    }

    /// Replaces any rule stored for `flag` with an enabled percentage rollout.
    pub fn override_rollout(&self, flag: &str, percentage: u32) {
        self.put_rule(flag, true, Some(percentage), None);
    }

    /// Replaces any rule stored for `flag` with an enabled weighted variant
    /// split.
    pub fn override_variants(&self, flag: &str, variants: Vec<(String, u32)>) {
        self.put_rule(flag, true, None, Some(variants));
    }

    /// Builds a rule from its parts and stores it under `flag`, overwriting
    /// the previous entry if there was one.
    fn put_rule(
        &self,
        flag: &str,
        enabled: bool,
        rollout_percentage: Option<u32>,
        variants: Option<Vec<(String, u32)>>,
    ) {
        let rule = MemoryFlagRule {
            enabled,
            rollout_percentage,
            variants,
        };
        self.rules.insert(flag.to_string(), rule);
    }
}
impl Default for MemoryFeatureDriver {
/// Equivalent to `MemoryFeatureDriver::new()`: a driver with no rules.
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl FeatureDriver for MemoryFeatureDriver {
    /// Without an identifier a flag counts as on only when its rule is
    /// enabled and carries no rollout percentage. `None` if no rule exists.
    async fn enabled(&self, flag: &str) -> Option<bool> {
        let rule = self.rules.get(flag)?;
        Some(rule.enabled && rule.rollout_percentage.is_none())
    }

    /// Applies the rule for `flag` to `identifier`: a disabled rule is off, a
    /// rollout rule is on when the hash bucket is below the percentage, and
    /// any other enabled rule is on. `None` if no rule exists.
    async fn enabled_for(&self, flag: &str, identifier: &str) -> Option<bool> {
        let rule = self.rules.get(flag)?;
        let verdict = match rule.rollout_percentage {
            _ if !rule.enabled => false,
            Some(pct) => calculate_hash_bucket(flag, identifier) < pct,
            None => true,
        };
        Some(verdict)
    }

    /// Returns `"disabled"` for a disabled rule, the resolved variant name
    /// for a variant rule (which may be `None` when the bucket falls outside
    /// every weight range), and otherwise `"enabled"` / `"disabled"`
    /// according to the rollout check. `None` if no rule exists.
    async fn variant(&self, flag: &str, identifier: &str) -> Option<String> {
        let rule = self.rules.get(flag)?;
        if !rule.enabled {
            return Some("disabled".to_string());
        }
        if let Some(variants) = rule.variants.as_ref() {
            return resolve_variant(variants, calculate_hash_bucket(flag, identifier));
        }
        // The rule is known to be enabled here, so only the rollout decides.
        let active = match rule.rollout_percentage {
            Some(pct) => calculate_hash_bucket(flag, identifier) < pct,
            None => true,
        };
        let label = if active { "enabled" } else { "disabled" };
        Some(label.to_string())
    }
}
/// Feature driver that reads flags from `FEATURE_*` environment variables.
#[non_exhaustive]
pub struct EnvFeatureDriver;
impl EnvFeatureDriver {
    /// Creates the (stateless) driver.
    pub fn new() -> Self {
        Self
    }

    /// Builds the environment variable name for `flag`: `FEATURE_` followed
    /// by the upper-cased flag with `-` replaced by `_`.
    fn env_key(flag: &str) -> String {
        format!("FEATURE_{}", flag.to_uppercase().replace('-', "_"))
    }

    /// Interprets the raw value of a flag's environment variable.
    ///
    /// * Empty (after trimming) yields `None`.
    /// * `true` / `1` / `yes` yield `"enabled"`; `false` / `0` / `no` yield
    ///   `"disabled"`. These literals are matched ASCII-case-insensitively so
    ///   that common spellings such as `TRUE` or `Yes` are not mistaken for
    ///   variant names.
    /// * A value ending in `%` that parses as a rollout yields `"enabled"`
    ///   when `identifier` hashes below the percentage and `"disabled"`
    ///   otherwise, including when no identifier is given.
    /// * A value containing `:` that parses to at least one variant is
    ///   resolved for `identifier`; the result may be `None` when the bucket
    ///   is outside every weight range.
    /// * Anything else is returned trimmed but otherwise verbatim.
    ///
    /// NOTE(review): a variant specification evaluated without an identifier
    /// is returned verbatim, so `enabled()` reports it as not enabled —
    /// confirm this is intended.
    fn parse_env_value(&self, value: &str, flag: &str, identifier: Option<&str>) -> Option<String> {
        let cleaned = value.trim();
        if cleaned.is_empty() {
            return None;
        }

        let is_any_of =
            |literals: &[&str]| literals.iter().any(|lit| cleaned.eq_ignore_ascii_case(lit));
        if is_any_of(&["true", "1", "yes"]) {
            return Some("enabled".to_string());
        }
        if is_any_of(&["false", "0", "no"]) {
            return Some("disabled".to_string());
        }

        if cleaned.ends_with('%')
            && let Some(pct) = parse_rollout(cleaned)
        {
            // Without an identifier there is nothing to bucket, so the
            // rollout counts as off.
            let in_rollout =
                identifier.is_some_and(|ident| calculate_hash_bucket(flag, ident) < pct);
            let label = if in_rollout { "enabled" } else { "disabled" };
            return Some(label.to_string());
        }

        if cleaned.contains(':')
            && let Some(ident) = identifier
        {
            let variants = parse_variants(cleaned);
            if !variants.is_empty() {
                return resolve_variant(&variants, calculate_hash_bucket(flag, ident));
            }
        }

        Some(cleaned.to_string())
    }
}
impl Default for EnvFeatureDriver {
/// Equivalent to `EnvFeatureDriver::new()`.
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl FeatureDriver for EnvFeatureDriver {
    /// Reads the flag's environment variable and reports whether it evaluates
    /// to `"enabled"` without an identifier. `None` when the variable is
    /// unset, not valid Unicode, or evaluates to nothing.
    async fn enabled(&self, flag: &str) -> Option<bool> {
        let raw = std::env::var(Self::env_key(flag)).ok()?;
        self.parse_env_value(&raw, flag, None)
            .map(|state| state == "enabled")
    }

    /// Same as `enabled`, but evaluates the value for `identifier`.
    async fn enabled_for(&self, flag: &str, identifier: &str) -> Option<bool> {
        let raw = std::env::var(Self::env_key(flag)).ok()?;
        self.parse_env_value(&raw, flag, Some(identifier))
            .map(|state| state == "enabled")
    }

    /// Returns whatever label the flag's environment variable evaluates to
    /// for `identifier`.
    async fn variant(&self, flag: &str, identifier: &str) -> Option<String> {
        let raw = std::env::var(Self::env_key(flag)).ok()?;
        self.parse_env_value(&raw, flag, Some(identifier))
    }
}
/// Feature driver backed by the `[features]` section of a `Rullst.toml` file.
#[non_exhaustive]
pub struct TomlFeatureDriver {
// Raw flag values (quotes stripped) keyed by flag name.
config: DashMap<String, String>,
// Location of the file read at construction time and by `reload`.
config_path: std::path::PathBuf,
}
impl TomlFeatureDriver {
    /// Creates a driver for `Rullst.toml` in the current working directory
    /// (falling back to `.` when the working directory cannot be determined)
    /// and loads it synchronously.
    ///
    /// A missing or unreadable file is not an error: the driver starts empty.
    pub fn new() -> Self {
        let config_path = std::env::current_dir()
            .unwrap_or_else(|_| std::path::PathBuf::from("."))
            .join("Rullst.toml");
        let driver = Self {
            config: DashMap::new(),
            config_path,
        };
        if let Ok(content) = std::fs::read_to_string(&driver.config_path) {
            driver.load_from_str(&content);
        }
        driver
    }

    /// Re-reads the configuration file and replaces the loaded flags.
    ///
    /// # Errors
    ///
    /// Returns the I/O error when the file cannot be read; the previously
    /// loaded flags are left untouched in that case.
    pub async fn reload(&self) -> Result<(), Box<dyn std::error::Error>> {
        let content = tokio::fs::read_to_string(&self.config_path).await?;
        self.load_from_str(&content);
        Ok(())
    }

    /// Parses the `key = value` lines of the `[features]` section and makes
    /// them the driver's configuration.
    ///
    /// Blank lines and lines starting with `#` are skipped, any other `[...]`
    /// header ends the section, text after a `#` in a value is dropped, and
    /// surrounding `"` / `'` characters are stripped from values.
    ///
    /// The new content is parsed into a local map first and only then applied
    /// to the shared map. The map is never cleared, so a flag that exists both
    /// before and after the load stays visible to concurrent readers
    /// throughout. The update is still applied key by key, not as one atomic
    /// snapshot.
    fn load_from_str(&self, content: &str) {
        let mut parsed = std::collections::HashMap::new();
        let mut in_features = false;
        for line in content.lines() {
            let trimmed = line.trim();
            if trimmed.is_empty() || trimmed.starts_with('#') {
                continue;
            }
            if trimmed.starts_with('[') {
                in_features = trimmed == "[features]";
                continue;
            }
            if !in_features {
                continue;
            }
            if let Some((key, raw)) = trimmed.split_once('=') {
                let value = raw.split('#').next().unwrap_or(raw).trim();
                let value = value.trim_matches('"').trim_matches('\'');
                // A later duplicate key overrides an earlier one.
                parsed.insert(key.trim().to_string(), value.to_string());
            }
        }

        // Drop flags that disappeared from the file, then add or overwrite
        // the rest.
        self.config.retain(|key, _| parsed.contains_key(key));
        for (key, value) in parsed {
            self.config.insert(key, value);
        }
    }

    /// Interprets a raw configuration value.
    ///
    /// * `true` / `1` / `yes` yield `"enabled"`; `false` / `0` / `no` yield
    ///   `"disabled"`.
    /// * A value ending in `%` that parses as a rollout yields `"enabled"`
    ///   when `identifier` hashes below the percentage and `"disabled"`
    ///   otherwise, including when no identifier is given.
    /// * A value containing `:` that parses to at least one variant is
    ///   resolved for `identifier`; the result may be `None` when the bucket
    ///   is outside every weight range.
    /// * Anything else is returned trimmed but otherwise verbatim.
    fn evaluate(&self, value: &str, flag: &str, identifier: Option<&str>) -> Option<String> {
        let cleaned = value.trim();
        if matches!(cleaned, "true" | "1" | "yes") {
            return Some("enabled".to_string());
        }
        if matches!(cleaned, "false" | "0" | "no") {
            return Some("disabled".to_string());
        }

        if cleaned.ends_with('%')
            && let Some(pct) = parse_rollout(cleaned)
        {
            let label = match identifier {
                Some(ident) if calculate_hash_bucket(flag, ident) < pct => "enabled",
                _ => "disabled",
            };
            return Some(label.to_string());
        }

        if cleaned.contains(':')
            && let Some(ident) = identifier
        {
            let variants = parse_variants(cleaned);
            if !variants.is_empty() {
                return resolve_variant(&variants, calculate_hash_bucket(flag, ident));
            }
        }

        Some(cleaned.to_string())
    }
}
impl Default for TomlFeatureDriver {
/// Equivalent to `TomlFeatureDriver::new()`, which reads the configuration
/// file from disk.
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl FeatureDriver for TomlFeatureDriver {
    /// Reports whether the configured value of `flag` evaluates to
    /// `"enabled"` without an identifier. `None` when the flag is not
    /// configured.
    async fn enabled(&self, flag: &str) -> Option<bool> {
        let entry = self.config.get(flag)?;
        self.evaluate(entry.value(), flag, None)
            .map(|state| state == "enabled")
    }

    /// Same as `enabled`, but evaluates the value for `identifier`.
    async fn enabled_for(&self, flag: &str, identifier: &str) -> Option<bool> {
        let entry = self.config.get(flag)?;
        self.evaluate(entry.value(), flag, Some(identifier))
            .map(|state| state == "enabled")
    }

    /// Returns whatever label the configured value evaluates to for
    /// `identifier`.
    async fn variant(&self, flag: &str, identifier: &str) -> Option<String> {
        let entry = self.config.get(flag)?;
        self.evaluate(entry.value(), flag, Some(identifier))
    }
}
/// Cached copy of one flag row fetched by `DbFeatureDriver`.
struct DbCacheValue {
// Value of the row's `enabled` column.
enabled: bool,
// Value of the row's `rollout_percentage` column, if it could be read.
rollout_percentage: Option<u32>,
// Unparsed `variants` column, if it could be read as a string.
variants: Option<String>,
// The entry is served from the cache only while `Instant::now()` is before this.
expires_at: Instant,
}
/// Feature driver that reads flags from the `rullst_feature_flags` table and
/// caches each row for a fixed time-to-live.
#[non_exhaustive]
pub struct DbFeatureDriver {
// Cached rows keyed by flag name.
cache: DashMap<String, DbCacheValue>,
// How long a freshly fetched row stays valid.
ttl: Duration,
}
impl DbFeatureDriver {
    /// Creates a driver whose cache entries live for five seconds.
    pub fn new() -> Self {
        Self::with_ttl(Duration::from_secs(5))
    }

    /// Creates a driver whose cache entries live for `ttl`.
    pub fn with_ttl(ttl: Duration) -> Self {
        Self {
            cache: DashMap::new(),
            ttl,
        }
    }

    /// Loads `(enabled, rollout_percentage, variants)` for `flag` from the
    /// `rullst_feature_flags` table.
    ///
    /// Returns `None` when no pool is available, the query fails, or no row
    /// matches. Column handling is deliberately lenient:
    ///
    /// * `enabled` is read as an integer (non-zero means on), then as a
    ///   boolean, and defaults to `false` when neither works.
    /// * `rollout_percentage` becomes `None` when it cannot be read as an
    ///   integer. A negative value is clamped to 0 (nobody is in the
    ///   rollout); previously it wrapped to a huge unsigned number and
    ///   enabled the flag for everyone.
    /// * `variants` becomes `None` when it cannot be read as a string.
    async fn fetch_flag_from_db(
        &self,
        flag: &str,
    ) -> Option<(bool, Option<u32>, Option<String>)> {
        use sqlx::Row;
        let pool = crate::db::safe_pool()?;
        let row = sqlx::query(
            "SELECT enabled, rollout_percentage, variants FROM rullst_feature_flags WHERE name = ?",
        )
        .bind(flag)
        .fetch_optional(pool)
        .await
        .ok()
        .flatten()?;

        let enabled = row
            .try_get::<i32, _>("enabled")
            .map(|v| v != 0)
            .or_else(|_| row.try_get::<bool, _>("enabled"))
            .unwrap_or(false);
        let rollout_percentage = row
            .try_get::<i32, _>("rollout_percentage")
            .ok()
            .map(|raw| u32::try_from(raw).unwrap_or(0));
        let variants = row.try_get::<String, _>("variants").ok();

        Some((enabled, rollout_percentage, variants))
    }

    /// Returns the flag's data from the cache while the entry is fresh, and
    /// otherwise fetches it from the database and caches the result for
    /// `ttl`. Flags that cannot be fetched are not cached.
    async fn resolve_flag(&self, flag: &str) -> Option<(bool, Option<u32>, Option<String>)> {
        // The map guard is scoped to this `if`, so it is released before the
        // database call below is awaited.
        if let Some(entry) = self.cache.get(flag)
            && Instant::now() < entry.expires_at
        {
            return Some((
                entry.enabled,
                entry.rollout_percentage,
                entry.variants.clone(),
            ));
        }

        let (enabled, rollout, variants) = self.fetch_flag_from_db(flag).await?;
        let fresh = DbCacheValue {
            enabled,
            rollout_percentage: rollout,
            variants: variants.clone(),
            expires_at: Instant::now() + self.ttl,
        };
        self.cache.insert(flag.to_string(), fresh);
        Some((enabled, rollout, variants))
    }

    /// Turns a flag row into a label.
    ///
    /// * A disabled flag yields `"disabled"`.
    /// * With an identifier and a variant specification that parses to at
    ///   least one variant, the variant is resolved; the result may be `None`
    ///   when the bucket is outside every weight range.
    /// * With a rollout percentage the result is `"enabled"` when the
    ///   identifier hashes below it and `"disabled"` otherwise, including
    ///   when no identifier is given.
    /// * Otherwise the flag is `"enabled"`.
    fn evaluate(
        &self,
        enabled: bool,
        rollout: Option<u32>,
        variants: Option<String>,
        flag: &str,
        identifier: Option<&str>,
    ) -> Option<String> {
        if !enabled {
            return Some("disabled".to_string());
        }

        if let (Some(spec), Some(ident)) = (variants.as_deref(), identifier) {
            let parsed = parse_variants(spec);
            if !parsed.is_empty() {
                return resolve_variant(&parsed, calculate_hash_bucket(flag, ident));
            }
        }

        let active = match (rollout, identifier) {
            (Some(pct), Some(ident)) => calculate_hash_bucket(flag, ident) < pct,
            (Some(_), None) => false,
            (None, _) => true,
        };
        let label = if active { "enabled" } else { "disabled" };
        Some(label.to_string())
    }
}
impl Default for DbFeatureDriver {
/// Equivalent to `DbFeatureDriver::new()`.
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl FeatureDriver for DbFeatureDriver {
    /// Reports whether `flag` evaluates to `"enabled"` without an identifier.
    /// `None` when the flag cannot be resolved.
    async fn enabled(&self, flag: &str) -> Option<bool> {
        let (is_on, rollout, variants) = self.resolve_flag(flag).await?;
        self.evaluate(is_on, rollout, variants, flag, None)
            .map(|state| state == "enabled")
    }

    /// Same as `enabled`, but evaluates the flag for `identifier`.
    async fn enabled_for(&self, flag: &str, identifier: &str) -> Option<bool> {
        let (is_on, rollout, variants) = self.resolve_flag(flag).await?;
        self.evaluate(is_on, rollout, variants, flag, Some(identifier))
            .map(|state| state == "enabled")
    }

    /// Returns whatever label the flag evaluates to for `identifier`.
    async fn variant(&self, flag: &str, identifier: &str) -> Option<String> {
        let (is_on, rollout, variants) = self.resolve_flag(flag).await?;
        self.evaluate(is_on, rollout, variants, flag, Some(identifier))
    }
}
/// Ordered chain of feature drivers; the first driver that returns `Some`
/// decides the answer.
#[non_exhaustive]
pub struct FeatureManager {
// Drivers in the order they are consulted.
drivers: Vec<Box<dyn FeatureDriver>>,
}
impl FeatureManager {
    /// Creates a manager with no drivers; every query answers `false` /
    /// `None` until drivers are added.
    pub fn new() -> Self {
        let drivers = Vec::new();
        Self { drivers }
    }

    /// Appends `driver` to the end of the chain and returns the manager for
    /// further chaining.
    pub fn add_driver(mut self, driver: Box<dyn FeatureDriver>) -> Self {
        self.drivers.push(driver);
        self
    }

    /// Returns the first answer any driver gives for `flag`, or `false` when
    /// none has one.
    pub async fn enabled(&self, flag: &str) -> bool {
        for candidate in self.drivers.iter() {
            match candidate.enabled(flag).await {
                Some(answer) => return answer,
                None => continue,
            }
        }
        false
    }

    /// Returns the first answer any driver gives for `flag` and `identifier`,
    /// or `false` when none has one.
    pub async fn enabled_for(&self, flag: &str, identifier: &str) -> bool {
        for candidate in self.drivers.iter() {
            match candidate.enabled_for(flag, identifier).await {
                Some(answer) => return answer,
                None => continue,
            }
        }
        false
    }

    /// Returns the first variant label any driver gives for `flag` and
    /// `identifier`, or `None` when none has one.
    pub async fn variant(&self, flag: &str, identifier: &str) -> Option<String> {
        for candidate in self.drivers.iter() {
            let label = candidate.variant(flag, identifier).await;
            if label.is_some() {
                return label;
            }
        }
        None
    }
}
impl Default for FeatureManager {
fn default() -> Self {
Self::new()
.add_driver(Box::new(MemoryFeatureDriver::new()))
.add_driver(Box::new(EnvFeatureDriver::new()))
.add_driver(Box::new(TomlFeatureDriver::new()))
.add_driver(Box::new(DbFeatureDriver::new()))
}
}
// Process-wide manager; filled at most once, either by `init` or lazily by `manager`.
static FEATURE_CELL: OnceLock<FeatureManager> = OnceLock::new();
/// Installs `manager` as the process-wide feature manager.
///
/// # Errors
///
/// Hands `manager` back as `Err` when a manager has already been installed
/// (by an earlier `init` or by a lazy `manager()` call).
pub fn init(manager: FeatureManager) -> Result<(), FeatureManager> {
FEATURE_CELL.set(manager)
}
/// Returns the process-wide feature manager, creating it with
/// `FeatureManager::default` on first use if `init` has not installed one.
pub fn manager() -> &'static FeatureManager {
FEATURE_CELL.get_or_init(FeatureManager::default)
}
/// Asks the process-wide manager whether `flag` is on; `false` when no driver
/// answers.
pub async fn enabled(flag: &str) -> bool {
manager().enabled(flag).await
}
/// Asks the process-wide manager whether `flag` is on for `identifier`;
/// `false` when no driver answers.
pub async fn enabled_for(flag: &str, identifier: &str) -> bool {
manager().enabled_for(flag, identifier).await
}
/// Asks the process-wide manager for the variant label of `flag` for
/// `identifier`; `None` when no driver answers.
pub async fn variant(flag: &str, identifier: &str) -> Option<String> {
manager().variant(flag, identifier).await
}
// Unit tests for the pure helper functions; the drivers themselves are not
// exercised here.
#[cfg(test)]
mod tests {
use super::*;
// Bucketing is repeatable for identical inputs and always below 100.
#[test]
fn test_calculate_hash_bucket() {
let b1 = calculate_hash_bucket("flag-a", "user-1");
let b2 = calculate_hash_bucket("flag-a", "user-1");
let b3 = calculate_hash_bucket("flag-a", "user-2");
assert_eq!(b1, b2);
assert!(b1 < 100);
assert!(b3 < 100);
}
// Percent sign and surrounding whitespace are optional; non-numbers are rejected.
#[test]
fn test_parse_rollout() {
assert_eq!(parse_rollout("30%"), Some(30));
assert_eq!(parse_rollout(" 100% "), Some(100));
assert_eq!(parse_rollout("0"), Some(0));
assert_eq!(parse_rollout("abc"), None);
}
// Well-formed entries are kept in order; entries without `:` are dropped.
#[test]
fn test_parse_variants() {
let parsed = parse_variants("control:50,treatment:50");
assert_eq!(parsed.len(), 2);
assert_eq!(parsed[0].0, "control");
assert_eq!(parsed[0].1, 50);
assert_eq!(parsed[1].0, "treatment");
assert_eq!(parsed[1].1, 50);
let parsed_empty = parse_variants("invalid");
assert!(parsed_empty.is_empty());
}
// Buckets map onto cumulative weight ranges; a bucket past the total resolves to nothing.
#[test]
fn test_resolve_variant() {
let variants = vec![("control".to_string(), 30), ("treatment".to_string(), 70)];
assert_eq!(resolve_variant(&variants, 10), Some("control".to_string()));
assert_eq!(
resolve_variant(&variants, 50),
Some("treatment".to_string())
);
assert_eq!(resolve_variant(&variants, 101), None);
}
}