// frigg 0.3.2
//
// Local-first MCP server for code understanding.
// Documentation
use super::kernel::PolicyProgram;
use super::trace::{PolicyEffect, PolicyStage};

/// Signature of a predicate leaf: looks at the context and reports whether the leaf holds.
pub(crate) type PredicateFn<Ctx> = for<'ctx> fn(&'ctx Ctx) -> bool;
/// Signature of a rule body: looks at the context and optionally yields a score effect.
pub(crate) type RuleFn<Ctx> = for<'ctx> fn(&'ctx Ctx) -> Option<PolicyEffect>;

/// A single named test over a scoring context `Ctx`.
///
/// `Clone`/`Copy` are implemented by hand instead of derived: `#[derive]` would add
/// `Ctx: Clone` / `Ctx: Copy` bounds, so leaves would stop being copyable for any context
/// type that is not itself `Copy`, even though every field here always is.
pub(crate) struct PredicateLeaf<Ctx: 'static> {
    /// Identifier reported for this leaf when it contributes to a predicate match.
    pub id: &'static str,
    /// The test itself; the leaf holds when this returns `true`.
    pub eval: PredicateFn<Ctx>,
}

impl<Ctx: 'static> Clone for PredicateLeaf<Ctx> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<Ctx: 'static> Copy for PredicateLeaf<Ctx> {}

impl<Ctx: 'static> PredicateLeaf<Ctx> {
    /// Builds a leaf. `const` so leaves can live in `const` rule tables.
    pub(crate) const fn new(id: &'static str, eval: PredicateFn<Ctx>) -> Self {
        Self { id, eval }
    }
}

/// A combination of three lists of [`PredicateLeaf`]s.
///
/// The predicate holds when every `all` leaf holds, at least one `any` leaf holds (a
/// requirement that is skipped when `any` is empty), and no `none` leaf holds.
///
/// `Clone`/`Copy` are implemented by hand for the same reason as on [`PredicateLeaf`]: the
/// fields are `&'static` slices, which are `Copy` regardless of `Ctx`.
pub(crate) struct Predicate<Ctx: 'static> {
    /// Leaves that must all hold.
    pub all: &'static [PredicateLeaf<Ctx>],
    /// Leaves of which at least one must hold, unless the list is empty.
    pub any: &'static [PredicateLeaf<Ctx>],
    /// Leaves of which none may hold.
    pub none: &'static [PredicateLeaf<Ctx>],
}

impl<Ctx: 'static> Clone for Predicate<Ctx> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<Ctx: 'static> Copy for Predicate<Ctx> {}

impl<Ctx: 'static> Predicate<Ctx> {
    /// A predicate with no leaves at all; it holds for every context.
    pub(crate) const ALWAYS: Self = Self {
        all: &[],
        any: &[],
        none: &[],
    };

    /// Builds a predicate from explicit `all`, `any` and `none` leaf lists.
    pub(crate) const fn new(
        all: &'static [PredicateLeaf<Ctx>],
        any: &'static [PredicateLeaf<Ctx>],
        none: &'static [PredicateLeaf<Ctx>],
    ) -> Self {
        Self { all, any, none }
    }

    /// Builds a predicate that holds when every leaf in `all` holds.
    pub(crate) const fn all(all: &'static [PredicateLeaf<Ctx>]) -> Self {
        Self::new(all, &[], &[])
    }

    /// Builds a predicate that holds when at least one leaf in `any` holds.
    pub(crate) const fn any(any: &'static [PredicateLeaf<Ctx>]) -> Self {
        Self::new(&[], any, &[])
    }
}

/// A scoring rule: when `when` matches, `eval` may produce a [`PolicyEffect`] that is applied
/// at `stage` under this rule's `id`.
///
/// `Clone`/`Copy` are implemented by hand so that rules stay copyable for non-`Copy` contexts.
/// This relies on [`Predicate`] being `Copy` for every `Ctx`, which the manual impls above
/// provide.
pub(crate) struct ScoreRule<Ctx: 'static> {
    /// Identifier reported when the rule's effect is applied.
    pub id: &'static str,
    /// Stage the rule's effect is attributed to.
    pub stage: PolicyStage,
    /// Guard that must match before `eval` is consulted.
    pub when: Predicate<Ctx>,
    /// Rule body; `None` means "no effect for this context".
    pub eval: RuleFn<Ctx>,
}

impl<Ctx: 'static> Clone for ScoreRule<Ctx> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<Ctx: 'static> Copy for ScoreRule<Ctx> {}

impl<Ctx: 'static> ScoreRule<Ctx> {
    /// Builds a rule guarded by `when`. `const` so rules can live in `const` rule tables.
    pub(crate) const fn when(
        id: &'static str,
        stage: PolicyStage,
        when: Predicate<Ctx>,
        eval: RuleFn<Ctx>,
    ) -> Self {
        Self {
            id,
            stage,
            when,
            eval,
        }
    }
}

/// An ordered group of [`ScoreRule`]s that are applied together.
///
/// `Clone`/`Copy` are implemented by hand: a `&'static` slice is `Copy` regardless of `Ctx`.
pub(crate) struct ScoreRuleSet<Ctx: 'static> {
    /// Rules in application order.
    pub rules: &'static [ScoreRule<Ctx>],
}

impl<Ctx: 'static> Clone for ScoreRuleSet<Ctx> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<Ctx: 'static> Copy for ScoreRuleSet<Ctx> {}

impl<Ctx: 'static> ScoreRuleSet<Ctx> {
    /// Wraps a static rule table.
    pub(crate) const fn new(rules: &'static [ScoreRule<Ctx>]) -> Self {
        Self { rules }
    }
}

/// Evaluates `predicate` against `ctx` and, on a match, returns the ids of the leaves that
/// contributed to it.
///
/// Leaves are evaluated in this order:
/// 1. the `all` leaves, stopping at the first that fails;
/// 2. every `any` leaf, so that each satisfied one is reported;
/// 3. the `none` leaves, stopping at the first that holds.
///
/// Returns `None` in any of these cases:
/// - an `all` leaf fails;
/// - `any` is non-empty and none of its leaves hold;
/// - a `none` leaf holds.
///
/// Otherwise returns the ids of all `all` leaves, followed by the ids of the satisfied `any`
/// leaves, in declaration order. `none` leaves never appear in the result.
fn matched_predicate_ids<Ctx: 'static>(
    ctx: &Ctx,
    predicate: &Predicate<Ctx>,
) -> Option<Vec<&'static str>> {
    // `Vec::new` does not allocate, so a predicate whose first `all` leaf fails costs nothing.
    let mut matched = Vec::new();

    for leaf in predicate.all {
        if !(leaf.eval)(ctx) {
            return None;
        }
        matched.push(leaf.id);
    }

    // Append the satisfied `any` leaves straight into the result rather than collecting them
    // into a temporary vector first. Every `any` leaf is still evaluated, in order.
    let ids_before_any = matched.len();
    matched.extend(
        predicate
            .any
            .iter()
            .filter(|leaf| (leaf.eval)(ctx))
            .map(|leaf| leaf.id),
    );
    // An empty `any` list imposes no requirement; a non-empty one needs at least one hit.
    if !predicate.any.is_empty() && matched.len() == ids_before_any {
        return None;
    }

    if predicate.none.iter().any(|leaf| (leaf.eval)(ctx)) {
        return None;
    }

    Some(matched)
}

/// Returns `true` when `predicate` holds for `ctx`.
///
/// Uses the same evaluation routine as rule application, so a predicate gives the same
/// answer here as when it guards a [`ScoreRule`].
pub(crate) fn predicate_matches<Ctx: 'static>(ctx: &Ctx, predicate: Predicate<Ctx>) -> bool {
    let matched_ids = matched_predicate_ids(ctx, &predicate);
    matched_ids.is_some()
}

/// Applies each rule in `rules` to `program`, in slice order.
///
/// A rule fires only when its `when` predicate matches `ctx` *and* its `eval` returns an
/// effect. `eval` is never called for a rule whose predicate does not match. Each firing
/// rule is handed to `PolicyProgram::apply_effect` together with the ids of the predicate
/// leaves that matched.
///
/// Returns `true` if at least one rule fired.
pub(crate) fn apply_score_rules<Ctx: 'static>(
    program: &mut PolicyProgram,
    ctx: &Ctx,
    rules: &[ScoreRule<Ctx>],
) -> bool {
    rules.iter().fold(false, |fired_any, rule| {
        let Some(predicate_ids) = matched_predicate_ids(ctx, &rule.when) else {
            return fired_any;
        };
        match (rule.eval)(ctx) {
            Some(effect) => {
                program.apply_effect(rule.id, rule.stage, &predicate_ids, effect);
                true
            }
            None => fired_any,
        }
    })
}

/// Applies every rule set in `rule_sets` to `program`, in slice order.
///
/// Every set is applied, even after an earlier set has fired. Returns `true` if any rule in
/// any set fired.
pub(crate) fn apply_score_rule_sets<Ctx: 'static>(
    program: &mut PolicyProgram,
    ctx: &Ctx,
    rule_sets: &[ScoreRuleSet<Ctx>],
) -> bool {
    rule_sets.iter().fold(false, |applied_before, rule_set| {
        // Run this set first and combine afterwards. Writing `applied_before || apply(...)`
        // would short-circuit and skip every set after the first one that fires.
        let applied_here = apply_score_rules(program, ctx, rule_set.rules);
        applied_before || applied_here
    })
}

#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Clone, Copy)]
    struct TestCtx {
        enabled: bool,
        multiplier: f32,
    }

    fn gate(ctx: &TestCtx) -> bool {
        ctx.enabled
    }

    const TEST_ENABLED: PredicateLeaf<TestCtx> = PredicateLeaf::new("test.enabled", gate);
    const TEST_GATE: PredicateLeaf<TestCtx> = PredicateLeaf::new("test.gate", gate);
    const TEST_ENABLED_PREDICATES: &[PredicateLeaf<TestCtx>] = &[TEST_ENABLED];
    const TEST_GATE_PREDICATES: &[PredicateLeaf<TestCtx>] = &[TEST_GATE];

    fn add_rule(ctx: &TestCtx) -> Option<PolicyEffect> {
        ctx.enabled.then_some(PolicyEffect::Add(2.0))
    }

    fn mul_rule(ctx: &TestCtx) -> Option<PolicyEffect> {
        ctx.enabled.then_some(PolicyEffect::Multiply(ctx.multiplier))
    }

    // Rule table with only the additive rule, gated on `test.enabled`.
    const ADD_RULES: &[ScoreRule<TestCtx>] = &[ScoreRule::when(
        "test.add",
        PolicyStage::PathQuality,
        Predicate::all(TEST_ENABLED_PREDICATES),
        add_rule,
    )];

    // Rule table with only the multiplicative rule, with no predicate gate.
    const MUL_RULES: &[ScoreRule<TestCtx>] = &[ScoreRule::when(
        "test.mul",
        PolicyStage::PathQuality,
        Predicate::ALWAYS,
        mul_rule,
    )];

    #[test]
    fn policy_rules_apply_in_declared_order() {
        let ctx = TestCtx {
            enabled: true,
            multiplier: 3.0,
        };
        let rules = &[
            ScoreRule::when(
                "test.add",
                PolicyStage::PathQuality,
                Predicate::all(TEST_ENABLED_PREDICATES),
                add_rule,
            ),
            ScoreRule::when(
                "test.mul",
                PolicyStage::PathQuality,
                Predicate::ALWAYS,
                mul_rule,
            ),
        ];
        let mut program = PolicyProgram::with_trace(1.0);
        apply_score_rules(&mut program, &ctx, rules);
        let evaluation = program.finish();
        let trace = evaluation.trace.expect("trace");
        assert_eq!(evaluation.score, 9.0);
        assert_eq!(trace.rules[0].predicate_ids, vec!["test.enabled"]);
        assert!(trace.rules[1].predicate_ids.is_empty());
    }

    #[test]
    fn policy_rules_predicates_are_deterministic() {
        let predicate = Predicate::all(TEST_GATE_PREDICATES);
        assert!(predicate_matches(
            &TestCtx {
                enabled: true,
                multiplier: 1.0,
            },
            predicate,
        ));
        assert!(!predicate_matches(
            &TestCtx {
                enabled: false,
                multiplier: 1.0,
            },
            predicate,
        ));
    }

    #[test]
    fn policy_rules_any_and_none_clauses() {
        let any_gate = Predicate::any(TEST_GATE_PREDICATES);
        let none_gate = Predicate::new(&[], &[], TEST_GATE_PREDICATES);
        let enabled = TestCtx {
            enabled: true,
            multiplier: 1.0,
        };
        let disabled = TestCtx {
            enabled: false,
            multiplier: 1.0,
        };

        assert!(predicate_matches(&enabled, any_gate));
        assert!(!predicate_matches(&disabled, any_gate));
        assert!(!predicate_matches(&enabled, none_gate));
        assert!(predicate_matches(&disabled, none_gate));

        // Satisfied `any` leaves are reported; `none` leaves never are.
        assert_eq!(
            matched_predicate_ids(&enabled, &any_gate),
            Some(vec!["test.gate"])
        );
        assert_eq!(
            matched_predicate_ids(&disabled, &none_gate),
            Some(Vec::new())
        );
    }

    #[test]
    fn policy_rule_sets_apply_every_set_in_order() {
        let ctx = TestCtx {
            enabled: true,
            multiplier: 3.0,
        };
        let mut program = PolicyProgram::with_trace(1.0);
        let applied = apply_score_rule_sets(
            &mut program,
            &ctx,
            &[ScoreRuleSet::new(ADD_RULES), ScoreRuleSet::new(MUL_RULES)],
        );
        assert!(applied);
        // (1 + 2) * 3: the second set still runs after the first one fired.
        assert_eq!(program.finish().score, 9.0);
    }

    #[test]
    fn policy_rules_report_when_nothing_fires() {
        let ctx = TestCtx {
            enabled: false,
            multiplier: 3.0,
        };
        let mut program = PolicyProgram::with_trace(1.0);
        assert!(!apply_score_rules(&mut program, &ctx, ADD_RULES));
        assert!(!apply_score_rule_sets(
            &mut program,
            &ctx,
            &[ScoreRuleSet::new(ADD_RULES), ScoreRuleSet::new(MUL_RULES)],
        ));
    }
}