use std::collections::{HashMap, HashSet};
use serde::{Deserialize, Serialize};
use crate::error::FeatureFlagError;
use crate::predicate::Predicate;
/// Name of one of a flag's variants.
///
/// A plain `String` alias, so the type itself accepts any text; whether a
/// name is actually declared on a flag is checked by `Flag::validate`.
pub type Variant = String;
/// A targeting rule attached to a [`Flag`].
///
/// Pairs a predicate (`when`) with an [`Outcome`]. NOTE(review): presumably
/// the outcome applies when the predicate matches; rule evaluation is not in
/// this file section — confirm against the evaluator.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Rule {
/// Rule identifier, quoted in `Flag::validate` error messages. That method
/// does not check it for emptiness or uniqueness.
pub id: String,
/// Condition for this rule (type from `crate::predicate`).
pub when: Predicate,
/// Variant selection for this rule; every variant name it mentions is
/// checked against the owning flag's `variants` by `Flag::validate`.
pub outcome: Outcome,
/// Optional free-form description; `None` when absent from the input.
#[serde(default)]
pub description: Option<String>,
}
/// What a [`Rule`] resolves to.
///
/// Serialized as an internally tagged object: a `kind` field holds the
/// snake_case variant name next to that variant's own fields, e.g.
/// `{"kind": "variant", "variant": "on"}` or
/// `{"kind": "rollout", "variants": [{"variant": "on", "weight": 100}]}`.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum Outcome {
/// A single fixed variant.
Variant {
variant: Variant,
},
/// A weighted split across variants; `Flag::validate` requires the
/// weights to total exactly 100.
Rollout {
variants: Vec<RolloutEntry>,
},
}
/// One arm of an [`Outcome::Rollout`].
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct RolloutEntry {
/// Variant for this arm; must be one of the owning flag's `variants`.
pub variant: Variant,
/// Share assigned to this arm. All weights in one rollout must total 100
/// (checked by `Flag::validate`), so presumably a percentage — confirm
/// against the evaluator.
pub weight: u32,
}
/// A feature flag: its declared variants, a default variant, and targeting
/// rules.
///
/// Deserializing does not enforce consistency; call [`Flag::validate`], or
/// load through `FlagSet::from_json`, which validates every flag.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Flag {
/// Flag identifier; `validate` requires it to be non-empty and
/// `FlagSet::validate` requires it to be unique within the set.
pub id: String,
/// Optional free-form description; `None` when absent from the input.
#[serde(default)]
pub description: Option<String>,
/// Every variant name this flag may use; must be non-empty.
pub variants: Vec<Variant>,
/// Must be one of `variants`. NOTE(review): presumably served when no
/// rule applies — confirm against the evaluator.
pub default_variant: Variant,
/// Targeting rules; an empty list when absent from the input.
#[serde(default)]
pub rules: Vec<Rule>,
/// Whether the flag is switched on; `true` when absent from the input
/// (via `default_true`).
#[serde(default = "default_true")]
pub enabled: bool,
}
/// Serde default for `Flag::enabled`: a flag whose input omits `enabled`
/// is treated as enabled.
const fn default_true() -> bool {
    true
}
impl Flag {
    /// Checks that this flag is internally consistent.
    ///
    /// Checks run in a fixed order and stop at the first failure:
    /// 1. the id, 2. the variant list, 3. the default variant, then
    /// 4. each rule in order (for rollouts: the weight total first, then
    ///    each entry's variant).
    ///
    /// # Errors
    ///
    /// Returns [`FeatureFlagError::Invalid`] with a descriptive message when:
    /// - `id` is empty;
    /// - `variants` is empty;
    /// - `default_variant` is not one of `variants`;
    /// - a rule's rollout weights do not total exactly 100;
    /// - a rule's outcome names a variant that is not one of `variants`.
    pub fn validate(&self) -> Result<(), FeatureFlagError> {
        if self.id.is_empty() {
            return Err(FeatureFlagError::Invalid(
                "flag.id must be non-empty".into(),
            ));
        }
        if self.variants.is_empty() {
            return Err(FeatureFlagError::Invalid(format!(
                "flag {}: variants must be non-empty",
                self.id
            )));
        }

        let known: HashSet<&str> = self.variants.iter().map(String::as_str).collect();
        if !known.contains(self.default_variant.as_str()) {
            return Err(FeatureFlagError::Invalid(format!(
                "flag {}: default_variant {:?} not in variants",
                self.id, self.default_variant
            )));
        }

        // Shared by both outcome kinds so the "unknown variant" error is
        // built in exactly one place.
        let ensure_known = |rule_id: &str, variant: &str| -> Result<(), FeatureFlagError> {
            if known.contains(variant) {
                Ok(())
            } else {
                Err(FeatureFlagError::Invalid(format!(
                    "flag {}: rule {:?} references unknown variant {:?}",
                    self.id, rule_id, variant
                )))
            }
        };

        for rule in &self.rules {
            match &rule.outcome {
                Outcome::Variant { variant } => {
                    ensure_known(rule.id.as_str(), variant.as_str())?;
                }
                Outcome::Rollout { variants } => {
                    // Weights come from external config. Summing in u32 could
                    // overflow: a debug build would panic, and a release build
                    // would wrap (e.g. [u32::MAX, 101] -> 100) and accept a
                    // bogus rollout. Widening each weight to u64 gives the true
                    // total for any realistic number of entries.
                    let total: u64 = variants.iter().map(|e| u64::from(e.weight)).sum();
                    if total != 100 {
                        return Err(FeatureFlagError::Invalid(format!(
                            "flag {}: rule {:?} rollout weights total {} (must be 100)",
                            self.id, rule.id, total
                        )));
                    }
                    for entry in variants {
                        ensure_known(rule.id.as_str(), entry.variant.as_str())?;
                    }
                }
            }
        }
        Ok(())
    }
}
/// A collection of flags with a version label.
///
/// [`FlagSet::from_json`] parses and validates one in a single step.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct FlagSet {
/// Version label; empty string when absent from the input. Not interpreted
/// in this file section.
#[serde(default)]
pub version: String,
/// The flags in this set; `FlagSet::validate` requires unique ids.
pub flags: Vec<Flag>,
}
impl FlagSet {
    /// Parses a flag set from JSON text and validates it before returning it.
    ///
    /// # Errors
    ///
    /// Returns an error if `raw` does not deserialize (the serde error is
    /// converted into [`FeatureFlagError`] by `?`), or if
    /// [`FlagSet::validate`] rejects the parsed set.
    pub fn from_json(raw: &str) -> Result<Self, FeatureFlagError> {
        let set = serde_json::from_str::<Self>(raw)?;
        set.validate().map(|()| set)
    }

    /// Checks that flag ids are unique and that every flag is valid.
    ///
    /// Flags are visited in order. For each one the duplicate-id check runs
    /// before [`Flag::validate`], and the first failure is returned.
    ///
    /// # Errors
    ///
    /// Returns [`FeatureFlagError::Invalid`] for a repeated flag id. Otherwise
    /// it returns the error [`Flag::validate`] reports for the first invalid
    /// flag.
    pub fn validate(&self) -> Result<(), FeatureFlagError> {
        let mut ids = HashSet::with_capacity(self.flags.len());
        self.flags.iter().try_for_each(|flag| {
            if ids.insert(flag.id.as_str()) {
                flag.validate()
            } else {
                Err(FeatureFlagError::Invalid(format!(
                    "duplicate flag id: {}",
                    flag.id
                )))
            }
        })
    }

    /// Builds a lookup table from flag id to flag, borrowing from `self`.
    ///
    /// Ids are not checked here. If they repeat (which
    /// [`FlagSet::validate`] would reject), the later flag wins.
    #[must_use]
    pub fn index(&self) -> HashMap<&str, &Flag> {
        let mut by_id = HashMap::with_capacity(self.flags.len());
        for flag in &self.flags {
            by_id.insert(flag.id.as_str(), flag);
        }
        by_id
    }
}