feature-flag 0.1.0

Server-side feature flag evaluation for async Rust: targeting rules, sticky percentage rollouts, hot reload, zero RNG.
Documentation
//! Flag / Rule / FlagSet types.

use std::collections::{HashMap, HashSet};

use serde::{Deserialize, Serialize};

use crate::error::FeatureFlagError;
use crate::predicate::Predicate;

/// Name of one possible flag outcome, such as `"on"`, `"off"` or
/// `"experiment-a"`. Plain owned string; no format is enforced here.
pub type Variant = std::string::String;

/// One targeting rule attached to a [`Flag`].
///
/// A flag's rules are tried in the order they are declared, and the first
/// rule whose predicate matches wins. A matching rule yields either:
///
/// - a fixed `variant` (e.g. `"on"`), or
/// - a `rollout` that maps the subject to a variant via sticky bucketing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Rule {
    /// Stable identifier, reported in telemetry.
    pub id: String,
    /// Condition that must hold for this rule to fire.
    pub when: Predicate,
    /// Result produced when this rule fires.
    pub outcome: Outcome,
    /// Optional free-form note for operators; `None` when absent from JSON.
    #[serde(default)]
    pub description: Option<String>,
}

/// Result of a matching [`Rule`].
///
/// Serialized as an internally tagged object: the `"kind"` field holds the
/// snake_case variant name (`"variant"` or `"rollout"`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum Outcome {
    /// Always yield one specific variant.
    Variant {
        /// The variant to yield.
        variant: Variant,
    },
    /// Split subjects across several variants by their id.
    /// Weights must add up to 100. Bucketing is sticky (SHA-256 mod 100).
    Rollout {
        /// Weighted slots; weights are percentages that must total 100.
        variants: Vec<RolloutEntry>,
    },
}

/// A single weighted slot inside an [`Outcome::Rollout`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RolloutEntry {
    /// Variant assigned to subjects bucketed into this slot.
    pub variant: Variant,
    /// Share of subjects in percentage points (expected `0..=100`); the
    /// weights of all entries in one rollout must add up to 100.
    pub weight: u32,
}

/// A single feature flag: its possible variants, its fallback, and its
/// ordered targeting rules.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Flag {
    /// Stable identifier; doubles as the flag's public name.
    pub id: String,
    /// Optional human-readable description; `None` when absent from JSON.
    #[serde(default)]
    pub description: Option<String>,
    /// Every variant name this flag can yield. The default and every rule
    /// outcome must reference values from this list.
    pub variants: Vec<Variant>,
    /// Variant returned when no rule matches.
    pub default_variant: Variant,
    /// Targeting rules in evaluation order; empty when absent from JSON.
    #[serde(default)]
    pub rules: Vec<Rule>,
    /// Kill switch. When `false`, the evaluator skips all rules and returns
    /// `default_variant`. Defaults to `true` when absent from JSON.
    #[serde(default = "default_true")]
    pub enabled: bool,
}

/// Serde default for [`Flag::enabled`]: a flag with no `enabled` field is on.
const fn default_true() -> bool {
    true
}

impl Flag {
    /// Validate cross-field invariants. Called by [`FlagSet::validate`].
    pub fn validate(&self) -> Result<(), FeatureFlagError> {
        if self.id.is_empty() {
            return Err(FeatureFlagError::Invalid(
                "flag.id must be non-empty".into(),
            ));
        }
        if self.variants.is_empty() {
            return Err(FeatureFlagError::Invalid(format!(
                "flag {}: variants must be non-empty",
                self.id
            )));
        }
        let known: HashSet<&str> = self.variants.iter().map(String::as_str).collect();
        if !known.contains(self.default_variant.as_str()) {
            return Err(FeatureFlagError::Invalid(format!(
                "flag {}: default_variant {:?} not in variants",
                self.id, self.default_variant
            )));
        }
        for rule in &self.rules {
            match &rule.outcome {
                Outcome::Variant { variant } => {
                    if !known.contains(variant.as_str()) {
                        return Err(FeatureFlagError::Invalid(format!(
                            "flag {}: rule {:?} references unknown variant {:?}",
                            self.id, rule.id, variant
                        )));
                    }
                }
                Outcome::Rollout { variants } => {
                    let total: u32 = variants.iter().map(|e| e.weight).sum();
                    if total != 100 {
                        return Err(FeatureFlagError::Invalid(format!(
                            "flag {}: rule {:?} rollout weights total {} (must be 100)",
                            self.id, rule.id, total
                        )));
                    }
                    for entry in variants {
                        if !known.contains(entry.variant.as_str()) {
                            return Err(FeatureFlagError::Invalid(format!(
                                "flag {}: rule {:?} references unknown variant {:?}",
                                self.id, rule.id, entry.variant
                            )));
                        }
                    }
                }
            }
        }
        Ok(())
    }
}

/// A bundle of flags, normally loaded from JSON.
///
/// Build one with `FlagSet::from_json`, or construct it directly in code
/// (handy for tests).
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct FlagSet {
    /// Free-form bundle version, logged by [`crate::FlagEvaluator`]; empty
    /// when absent from JSON.
    #[serde(default)]
    pub version: String,
    /// Flags in the bundle.
    pub flags: Vec<Flag>,
}

impl FlagSet {
    /// Deserialize a `FlagSet` from a JSON string, then validate it.
    ///
    /// # Errors
    ///
    /// Fails if `raw` does not deserialize into a `FlagSet` (the
    /// `serde_json` error is converted into [`FeatureFlagError`]), or if the
    /// parsed set violates any invariant checked by [`FlagSet::validate`].
    pub fn from_json(raw: &str) -> Result<Self, FeatureFlagError> {
        let set: Self = serde_json::from_str(raw)?;
        set.validate().map(|()| set)
    }

    /// Validate every flag and reject duplicate flag ids.
    ///
    /// Flags are processed in order. For each flag, the duplicate-id check
    /// runs before that flag's own [`Flag::validate`].
    ///
    /// # Errors
    ///
    /// Returns the first error encountered: [`FeatureFlagError::Invalid`]
    /// for a duplicate id, or whatever [`Flag::validate`] reports.
    pub fn validate(&self) -> Result<(), FeatureFlagError> {
        let mut seen_ids: HashSet<&str> = HashSet::with_capacity(self.flags.len());
        self.flags.iter().try_for_each(|flag| {
            if seen_ids.insert(flag.id.as_str()) {
                flag.validate()
            } else {
                Err(FeatureFlagError::Invalid(format!(
                    "duplicate flag id: {}",
                    flag.id
                )))
            }
        })
    }

    /// Map each flag id to its flag for O(1) lookup. Used internally by
    /// [`crate::FlagEvaluator`].
    ///
    /// If ids repeat (only possible when the set was not validated), the last
    /// flag with that id wins.
    #[must_use]
    pub fn index(&self) -> HashMap<&str, &Flag> {
        let mut by_id = HashMap::with_capacity(self.flags.len());
        for flag in &self.flags {
            by_id.insert(flag.id.as_str(), flag);
        }
        by_id
    }
}