use std::collections::{HashMap, HashSet, VecDeque};
use crate::types::{CompiledExpr, CompiledRule};
use crate::{CompileError, Expr, FieldRegistry, Rule, RuleSet, Terminal};
/// Validates a rule-set definition and lowers it into an executable [`RuleSet`].
///
/// Validation runs in a fixed order and stops at the first failure: missing
/// conditions, duplicate rule names, terminal checks (at least one terminal,
/// each naming a defined rule), duplicate terminals, undefined rule
/// references, and finally the topological sort, which reports cycles.
///
/// On success the compiled rules are emitted in the order returned by
/// `topological_sort`, each rule's `index` is its position in that order,
/// terminals are sorted by ascending `priority`, and `terminal_indices[i]` is
/// the compiled index of the rule named by `terminals[i]`.
///
/// # Errors
/// Returns the first [`CompileError`] raised by any of the checks above.
pub(crate) fn compile(
    rules: &[Rule],
    mut terminals: Vec<Terminal>,
) -> Result<RuleSet, CompileError> {
    // The order of these checks decides which error surfaces when several apply.
    check_missing_conditions(rules)?;
    check_duplicates(rules)?;
    check_terminals(&terminals, rules)?;
    check_duplicate_terminals(&terminals)?;

    let rule_map: HashMap<&str, &Rule> = rules
        .iter()
        .map(|rule| (rule.name.as_str(), rule))
        .collect();
    check_references(rules, &rule_map)?;

    let order = topological_sort(rules, &rule_map)?;

    // Rule name -> position in the sorted order.
    let mut rule_indices: HashMap<String, usize> = HashMap::new();
    for (position, name) in order.iter().enumerate() {
        rule_indices.insert(name.clone(), position);
    }

    // Fields are registered in rule declaration order, before any expression
    // is lowered, so every `Compare` can resolve its field index.
    let mut field_registry = FieldRegistry::new();
    rules
        .iter()
        .for_each(|rule| collect_fields(condition_of(rule), &mut field_registry));

    let mut compiled_rules: Vec<CompiledRule> = Vec::with_capacity(order.len());
    for (position, name) in order.iter().enumerate() {
        let source = rule_map[name.as_str()];
        compiled_rules.push(CompiledRule {
            name: source.name.clone(),
            condition: compile_expr(condition_of(source), &field_registry, &rule_indices),
            index: position,
        });
    }

    terminals.sort_by_key(|terminal| terminal.priority);
    let mut terminal_indices = Vec::with_capacity(terminals.len());
    for terminal in &terminals {
        // Cannot miss: `check_terminals` verified every terminal names a rule.
        terminal_indices.push(rule_indices[&terminal.rule_name]);
    }

    Ok(RuleSet {
        rules: compiled_rules,
        terminals,
        field_registry,
        terminal_indices,
    })
}
/// Returns a reference to the rule's condition expression.
///
/// # Panics
/// Panics if the rule has no condition. `compile` rules this out by running
/// `check_missing_conditions` before anything calls this helper.
fn condition_of(rule: &Rule) -> &Expr {
    let condition = rule.condition.as_ref();
    condition.expect("condition validated by check_missing_conditions")
}
/// Ensures every rule carries a condition.
///
/// # Errors
/// Returns [`CompileError::MissingCondition`] naming the first rule (in
/// declaration order) whose `condition` is `None`.
fn check_missing_conditions(rules: &[Rule]) -> Result<(), CompileError> {
    match rules.iter().find(|rule| rule.condition.is_none()) {
        Some(incomplete) => Err(CompileError::MissingCondition {
            rule: incomplete.name.clone(),
        }),
        None => Ok(()),
    }
}
/// Ensures no two rules share a name.
///
/// # Errors
/// Returns [`CompileError::DuplicateRule`] for the first rule (in declaration
/// order) whose name was already used by an earlier rule.
fn check_duplicates(rules: &[Rule]) -> Result<(), CompileError> {
    let mut names_seen: HashSet<&str> = HashSet::new();
    // `insert` returns false when the name is already present.
    let repeated = rules
        .iter()
        .find(|&rule| !names_seen.insert(rule.name.as_str()));
    match repeated {
        Some(rule) => Err(CompileError::DuplicateRule {
            name: rule.name.clone(),
        }),
        None => Ok(()),
    }
}
/// Ensures no rule is declared as a terminal more than once.
///
/// # Errors
/// Returns [`CompileError::DuplicateTerminal`] for the first terminal whose
/// `rule_name` repeats that of an earlier terminal.
fn check_duplicate_terminals(terminals: &[Terminal]) -> Result<(), CompileError> {
    let mut names_seen: HashSet<&str> = HashSet::new();
    terminals.iter().try_for_each(|terminal| {
        // `insert` returns true only the first time a name is seen.
        if names_seen.insert(terminal.rule_name.as_str()) {
            Ok(())
        } else {
            Err(CompileError::DuplicateTerminal {
                terminal: terminal.rule_name.clone(),
            })
        }
    })
}
/// Ensures at least one terminal exists and that each terminal names a defined rule.
///
/// # Errors
/// * [`CompileError::NoTerminals`] when `terminals` is empty.
/// * [`CompileError::UndefinedTerminal`] for the first terminal whose
///   `rule_name` matches no rule in `rules`.
fn check_terminals(terminals: &[Terminal], rules: &[Rule]) -> Result<(), CompileError> {
    if terminals.is_empty() {
        return Err(CompileError::NoTerminals);
    }

    let defined: HashSet<&str> = rules.iter().map(|rule| rule.name.as_str()).collect();
    let dangling = terminals
        .iter()
        .find(|terminal| !defined.contains(terminal.rule_name.as_str()));

    match dangling {
        Some(terminal) => Err(CompileError::UndefinedTerminal {
            terminal: terminal.rule_name.clone(),
        }),
        None => Ok(()),
    }
}
/// Verifies that every rule reference in every rule's condition names a known rule.
///
/// Rules are checked in declaration order and the walk stops at the first failure.
///
/// # Errors
/// Propagates the error produced by `collect_and_check_refs`.
fn check_references(rules: &[Rule], rule_map: &HashMap<&str, &Rule>) -> Result<(), CompileError> {
    rules
        .iter()
        .try_for_each(|rule| collect_and_check_refs(condition_of(rule), &rule.name, rule_map))
}
/// Recursively checks that each `RuleRef` inside `expr` resolves to a key of `rule_map`.
///
/// `rule_name` is the rule that owns `expr`; it is only used to label the error.
/// Binary nodes are checked left operand first.
///
/// # Errors
/// Returns [`CompileError::UndefinedRuleRef`] for the first unresolved reference.
fn collect_and_check_refs(
    expr: &Expr,
    rule_name: &str,
    rule_map: &HashMap<&str, &Rule>,
) -> Result<(), CompileError> {
    match expr {
        // Comparisons reference fields, not rules.
        Expr::Compare { .. } => Ok(()),
        Expr::Not(operand) => collect_and_check_refs(operand, rule_name, rule_map),
        Expr::And(lhs, rhs) | Expr::Or(lhs, rhs) => {
            collect_and_check_refs(lhs, rule_name, rule_map)?;
            collect_and_check_refs(rhs, rule_name, rule_map)
        }
        Expr::RuleRef(target) if rule_map.contains_key(target.as_str()) => Ok(()),
        Expr::RuleRef(target) => Err(CompileError::UndefinedRuleRef {
            rule: rule_name.to_owned(),
            reference: target.clone(),
        }),
    }
}
/// Orders rule names so that every rule comes after all the rules it references.
///
/// Implements Kahn's algorithm: count, for each rule, how many references it
/// makes to known rules (its in-degree), repeatedly emit a rule whose count is
/// zero, and decrement the counts of the rules that depend on it.
///
/// The result is deterministic. The ready queue is seeded in rule declaration
/// order rather than by iterating a `HashMap`, whose iteration order is
/// arbitrary and differs between runs. Identical input therefore always
/// yields the same order, and so the same compiled rule indices.
///
/// A reference repeated within one condition is counted once per occurrence on
/// both sides (in-degree and dependents list), so the bookkeeping stays balanced.
///
/// # Errors
/// Returns [`CompileError::CyclicDependency`] when not every rule could be
/// emitted; `path` is whatever `find_cycle` reports for the rule graph.
fn topological_sort(
    rules: &[Rule],
    rule_map: &HashMap<&str, &Rule>,
) -> Result<Vec<String>, CompileError> {
    let rule_names: HashSet<&str> = rules.iter().map(|r| r.name.as_str()).collect();

    // dependents[d] = rules whose condition references `d`.
    let mut dependents: HashMap<&str, Vec<&str>> = HashMap::with_capacity(rules.len());
    // in_degree[r] = number of references `r` makes to known rules.
    let mut in_degree: HashMap<&str, usize> = HashMap::with_capacity(rules.len());
    for rule in rules {
        in_degree.entry(rule.name.as_str()).or_insert(0);
        dependents.entry(rule.name.as_str()).or_default();
    }

    for rule in rules {
        for dep in collect_rule_refs(condition_of(rule)) {
            // Resolve through `rule_names` to obtain a name that borrows from
            // `rules` (not from the temporary `dep`); unknown names are skipped.
            if let Some(&dep_name) = rule_names.get(dep.as_str()) {
                dependents
                    .entry(dep_name)
                    .or_default()
                    .push(rule.name.as_str());
                *in_degree.entry(rule.name.as_str()).or_insert(0) += 1;
            }
        }
    }

    // Seed in declaration order for a reproducible result. `seeded` guards
    // against enqueuing a name twice should `rules` contain repeated names.
    let mut seeded: HashSet<&str> = HashSet::new();
    let mut queue: VecDeque<&str> = VecDeque::new();
    for rule in rules {
        let name = rule.name.as_str();
        if in_degree.get(name) == Some(&0) && seeded.insert(name) {
            queue.push_back(name);
        }
    }

    let mut sorted: Vec<String> = Vec::with_capacity(rules.len());
    while let Some(name) = queue.pop_front() {
        if let Some(children) = dependents.get(name) {
            for &child in children {
                if let Some(deg) = in_degree.get_mut(child) {
                    *deg -= 1;
                    if *deg == 0 {
                        queue.push_back(child);
                    }
                }
            }
        }
        sorted.push(name.to_owned());
    }

    // Any rule left un-emitted never reached in-degree zero.
    if sorted.len() != rules.len() {
        let cycle = find_cycle(rules, rule_map);
        return Err(CompileError::CyclicDependency { path: cycle });
    }
    Ok(sorted)
}
/// Returns the names of all rules referenced inside `expr`.
///
/// Names are listed in the order `collect_rule_refs_inner` appends them, and a
/// name referenced several times appears once per occurrence.
fn collect_rule_refs(expr: &Expr) -> Vec<String> {
    let mut found: Vec<String> = Vec::new();
    collect_rule_refs_inner(expr, &mut found);
    found
}
/// Appends the name of every `RuleRef` found in `expr` to `refs`.
///
/// The expression tree is walked pre-order, left operand before right, so
/// names are appended in the order they occur reading the expression left to
/// right. Duplicates are kept.
fn collect_rule_refs_inner(expr: &Expr, refs: &mut Vec<String>) {
    // Explicit work stack instead of recursion; the visit order is unchanged.
    let mut pending: Vec<&Expr> = vec![expr];
    while let Some(current) = pending.pop() {
        match current {
            Expr::Compare { .. } => {}
            Expr::RuleRef(name) => refs.push(name.clone()),
            Expr::Not(operand) => pending.push(operand),
            Expr::And(lhs, rhs) | Expr::Or(lhs, rhs) => {
                // Push the right side first so the left side is popped (visited) first.
                pending.push(rhs);
                pending.push(lhs);
            }
        }
    }
}
/// Visitation state of a rule during the depth-first cycle search in `dfs`.
#[derive(Clone, Copy, PartialEq, Eq)]
enum DfsState {
/// The search has not reached this rule yet.
Unvisited,
/// The rule is on the current DFS path; an edge leading back to it closes a cycle.
InStack,
/// The rule and everything reachable from it were explored without finding a cycle.
Done,
}
/// Searches the rule reference graph for a cycle and returns it as a path of rule names.
///
/// Edges run from a rule to each rule it references; references that are not
/// keys of `rule_map` are ignored. Rules are tried as DFS roots in declaration
/// order and the first cycle reported by `dfs` is returned. The path starts
/// and ends with the same rule name.
///
/// Returns an empty vector when no cycle is found.
fn find_cycle(rules: &[Rule], rule_map: &HashMap<&str, &Rule>) -> Vec<String> {
    let mut adj: HashMap<&str, Vec<&str>> = HashMap::new();
    let mut state: HashMap<&str, DfsState> = HashMap::new();
    for rule in rules {
        let mut known_deps: Vec<&str> = Vec::new();
        for reference in collect_rule_refs(condition_of(rule)) {
            // Take the key stored in `rule_map` so the edge does not borrow
            // from the temporary `reference` string.
            if let Some((&key, _)) = rule_map.get_key_value(reference.as_str()) {
                known_deps.push(key);
            }
        }
        adj.insert(rule.name.as_str(), known_deps);
        state.insert(rule.name.as_str(), DfsState::Unvisited);
    }

    let mut stack: Vec<&str> = Vec::new();
    for rule in rules {
        let start = rule.name.as_str();
        // Skip rules already explored from an earlier root.
        if !matches!(state.get(start), Some(DfsState::Unvisited)) {
            continue;
        }
        if let Some(cycle) = dfs(start, &adj, &mut state, &mut stack) {
            return cycle;
        }
    }
    Vec::new()
}
/// Depth-first search from `node` that reports the first cycle it runs into.
///
/// * `adj` maps a node to the nodes it points at; a node missing from `adj`
///   is treated as having no outgoing edges.
/// * `state` records the visitation state per node; a node missing from
///   `state` is treated as unvisited.
/// * `stack` holds the current DFS path.
///
/// Returns `Some(path)` when an edge leads back to a node on the current
/// path. `path` runs from that node along the DFS path and ends with the same
/// node again. In that case `stack` and `state` are left as they were at the
/// moment of detection (nothing is unwound). Returns `None` once everything
/// reachable from `node` has been explored; `node` is then popped from
/// `stack` and marked `Done`.
///
/// # Panics
/// Panics if a node is marked `InStack` in `state` but is absent from `stack`.
fn dfs<'a>(
    node: &'a str,
    adj: &HashMap<&str, Vec<&'a str>>,
    state: &mut HashMap<&'a str, DfsState>,
    stack: &mut Vec<&'a str>,
) -> Option<Vec<String>> {
    state.insert(node, DfsState::InStack);
    stack.push(node);

    for &next in adj.get(node).into_iter().flatten() {
        match state.get(next).copied().unwrap_or(DfsState::Unvisited) {
            DfsState::Done => continue,
            DfsState::Unvisited => {
                if let Some(found) = dfs(next, adj, state, stack) {
                    return Some(found);
                }
            }
            DfsState::InStack => {
                // `next` is an ancestor on the current path: the cycle is the
                // path suffix starting at it, closed by repeating `next`.
                let start = stack
                    .iter()
                    .position(|&n| n == next)
                    .expect("node marked InStack must be present in stack");
                let mut path: Vec<String> =
                    stack[start..].iter().map(|&s| s.to_owned()).collect();
                path.push(next.to_owned());
                return Some(path);
            }
        }
    }

    stack.pop();
    state.insert(node, DfsState::Done);
    None
}
/// Registers every field named by a `Compare` node inside `expr` with `registry`.
///
/// The tree is walked left operand first, so fields reach
/// `registry.register` in the order they appear in the expression.
/// Rule references contribute no fields.
fn collect_fields(expr: &Expr, registry: &mut FieldRegistry) {
    match expr {
        Expr::RuleRef(_) => {}
        Expr::Compare { field, .. } => {
            registry.register(field);
        }
        Expr::Not(operand) => collect_fields(operand, registry),
        Expr::And(lhs, rhs) | Expr::Or(lhs, rhs) => {
            for operand in [lhs, rhs] {
                collect_fields(operand, registry);
            }
        }
    }
}
/// Lowers a source [`Expr`] into a [`CompiledExpr`], replacing names with indices.
///
/// Field names become the index returned by `field_registry.get`, and rule
/// references become the position stored in `rule_indices`. Logical nodes are
/// lowered recursively, left operand first.
///
/// # Panics
/// Panics if a compared field is not present in `field_registry`, or if a
/// referenced rule name is missing from `rule_indices`.
fn compile_expr(
    expr: &Expr,
    field_registry: &FieldRegistry,
    rule_indices: &HashMap<String, usize>,
) -> CompiledExpr {
    // Lowers a sub-expression and boxes the result for the recursive variants.
    let lower = |sub: &Expr| Box::new(compile_expr(sub, field_registry, rule_indices));

    match expr {
        Expr::RuleRef(name) => {
            let target = rule_indices
                .get(name)
                .expect("rule reference should be validated");
            CompiledExpr::RuleRef(*target)
        }
        Expr::Not(operand) => CompiledExpr::Not(lower(operand)),
        Expr::And(lhs, rhs) => CompiledExpr::And(lower(lhs), lower(rhs)),
        Expr::Or(lhs, rhs) => CompiledExpr::Or(lower(lhs), lower(rhs)),
        Expr::Compare { field, op, value } => {
            let field_index = field_registry
                .get(field)
                .expect("field should be registered");
            CompiledExpr::Compare {
                field_index,
                op: *op,
                value: value.clone(),
            }
        }
    }
}
// Unit tests for rule-set compilation, driven through `RuleSetBuilder`.
#[cfg(test)]
mod tests {
use crate::{field, rule_ref, CompileError, RuleSetBuilder};
// A single rule with one terminal compiles and yields exactly that rule.
#[test]
fn compile_simple_ruleset() {
let result = RuleSetBuilder::new()
.rule("age_check", |r| r.when(field("age").gte(18_i64)))
.terminal("age_check", 0)
.compile();
assert!(result.is_ok());
let ruleset = result.unwrap();
assert_eq!(ruleset.rules.len(), 1);
assert_eq!(ruleset.rules[0].name, "age_check");
}
// Two rules sharing a name are rejected with `DuplicateRule`.
#[test]
fn compile_duplicate_rule() {
let result = RuleSetBuilder::new()
.rule("r1", |r| r.when(field("x").eq(1_i64)))
.rule("r1", |r| r.when(field("y").eq(2_i64)))
.terminal("r1", 0)
.compile();
assert!(matches!(result, Err(CompileError::DuplicateRule { .. })));
}
// A rule set without any terminal is rejected with `NoTerminals`.
#[test]
fn compile_no_terminals() {
let result = RuleSetBuilder::new()
.rule("r1", |r| r.when(field("x").eq(1_i64)))
.compile();
assert!(matches!(result, Err(CompileError::NoTerminals)));
}
// A terminal naming a rule that does not exist is rejected with `UndefinedTerminal`.
#[test]
fn compile_undefined_terminal() {
let result = RuleSetBuilder::new()
.rule("r1", |r| r.when(field("x").eq(1_i64)))
.terminal("nonexistent", 0)
.compile();
assert!(matches!(
result,
Err(CompileError::UndefinedTerminal { .. })
));
}
// A condition referencing an unknown rule is rejected with `UndefinedRuleRef`.
#[test]
fn compile_undefined_rule_ref() {
let result = RuleSetBuilder::new()
.rule("r1", |r| r.when(rule_ref("nonexistent")))
.terminal("r1", 0)
.compile();
assert!(matches!(result, Err(CompileError::UndefinedRuleRef { .. })));
}
// Declaring the same rule as a terminal twice (even with different
// priorities) is rejected with `DuplicateTerminal`.
#[test]
fn compile_duplicate_terminal() {
let result = RuleSetBuilder::new()
.rule("r1", |r| r.when(field("x").eq(1_i64)))
.terminal("r1", 0)
.terminal("r1", 5)
.compile();
assert!(matches!(
result,
Err(CompileError::DuplicateTerminal { .. })
));
}
// A rule built without calling `when` is rejected with `MissingCondition`.
#[test]
fn compile_missing_condition() {
let result = RuleSetBuilder::new()
.rule("r1", |r| r)
.terminal("r1", 0)
.compile();
assert!(matches!(result, Err(CompileError::MissingCondition { .. })));
}
// Two rules referencing each other are rejected with `CyclicDependency`.
#[test]
fn compile_cycle_detection() {
let result = RuleSetBuilder::new()
.rule("a", |r| r.when(rule_ref("b")))
.rule("b", |r| r.when(rule_ref("a")))
.terminal("a", 0)
.compile();
assert!(matches!(result, Err(CompileError::CyclicDependency { .. })));
}
// A diamond (a -> b, a -> c, b -> d, c -> d) shares a dependency without
// forming a cycle, so it must compile.
#[test]
fn compile_diamond_dependency() {
let result = RuleSetBuilder::new()
.rule("d", |r| r.when(field("x").eq(1_i64)))
.rule("b", |r| r.when(rule_ref("d")))
.rule("c", |r| r.when(rule_ref("d")))
.rule("a", |r| r.when(rule_ref("b").and(rule_ref("c"))))
.terminal("a", 0)
.compile();
assert!(result.is_ok());
}
// In a chain leaf <- mid <- top, each rule's compiled index is smaller than
// the index of the rule that references it.
#[test]
fn topo_sort_dependencies_before_dependents() {
let ruleset = RuleSetBuilder::new()
.rule("leaf", |r| r.when(field("x").eq(1_i64)))
.rule("mid", |r| r.when(rule_ref("leaf")))
.rule("top", |r| r.when(rule_ref("mid")))
.terminal("top", 0)
.compile()
.unwrap();
let index_of =
|name: &str| -> usize { ruleset.rules.iter().find(|r| r.name == name).unwrap().index };
let leaf_idx = index_of("leaf");
let mid_idx = index_of("mid");
let top_idx = index_of("top");
assert!(leaf_idx < mid_idx);
assert!(mid_idx < top_idx);
}
// Terminals come out ordered by ascending priority, regardless of the order
// in which they were declared.
#[test]
fn terminals_sorted_by_priority() {
let ruleset = RuleSetBuilder::new()
.rule("r1", |r| r.when(field("x").eq(1_i64)))
.rule("r2", |r| r.when(field("y").eq(2_i64)))
.terminal("r2", 10)
.terminal("r1", 0)
.compile()
.unwrap();
assert_eq!(ruleset.terminals[0].rule_name, "r1");
assert_eq!(ruleset.terminals[0].priority, 0);
assert_eq!(ruleset.terminals[1].rule_name, "r2");
assert_eq!(ruleset.terminals[1].priority, 10);
}
// A three-rule cycle is reported with a path of at least three entries that
// starts and ends on the same rule.
#[test]
fn compile_three_node_cycle() {
let result = RuleSetBuilder::new()
.rule("a", |r| r.when(rule_ref("b")))
.rule("b", |r| r.when(rule_ref("c")))
.rule("c", |r| r.when(rule_ref("a")))
.terminal("a", 0)
.compile();
match result {
Err(CompileError::CyclicDependency { path }) => {
assert!(path.len() >= 3, "cycle path should have at least 3 nodes");
assert_eq!(path.first(), path.last());
}
other => panic!("expected CyclicDependency, got {other:?}"),
}
}
}