lemma/
analysis.rs

1//! Reference and dependency analysis utilities
2//!
3//! This module provides shared infrastructure for extracting references
4//! from expressions and analyzing dependencies between rules.
5//!
6//! Used by both semantic validation and evaluation.
7
8use crate::{Expression, ExpressionKind, FactType, FactValue, LemmaFact, LemmaRule, RuleResult};
9use std::collections::{HashMap, HashSet};
10
/// Fact and rule references found in an expression.
///
/// Each reference is stored as its path segments, e.g. `["employee", "name"]`.
/// `PartialEq`/`Eq` are derived so callers (and tests) can compare extracted
/// reference sets directly.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct References {
    /// Fact references (e.g., `["employee", "name"]`)
    pub facts: HashSet<Vec<String>>,
    /// Rule references (e.g., `["employee", "is_eligible"]`)
    pub rules: HashSet<Vec<String>>,
}
19
20/// Extract all fact and rule references from an expression.
21///
22/// Recursively walks the expression tree to find all references to facts and rules.
23/// Useful for dependency analysis and validation.
24///
25/// # Examples
26/// ```text
27/// Expression: price * quantity
28/// Returns: facts = ["price", "quantity"], rules = []
29///
30/// Expression: base_amount + adjustment?
31/// Returns: facts = ["base_amount"], rules = ["adjustment"]
32/// ```
33pub fn extract_references(expr: &Expression) -> References {
34    let mut refs = References::default();
35    collect_references(expr, &mut refs.facts, &mut refs.rules);
36    refs
37}
38
39/// Recursively collect all fact and rule references from an expression
40fn collect_references(
41    expr: &Expression,
42    fact_refs: &mut HashSet<Vec<String>>,
43    rule_refs: &mut HashSet<Vec<String>>,
44) {
45    match &expr.kind {
46        ExpressionKind::FactReference(fact_ref) => {
47            fact_refs.insert(fact_ref.reference.clone());
48        }
49        ExpressionKind::RuleReference(rule_ref) => {
50            rule_refs.insert(rule_ref.reference.clone());
51        }
52        ExpressionKind::Arithmetic(left, _op, right) => {
53            collect_references(left, fact_refs, rule_refs);
54            collect_references(right, fact_refs, rule_refs);
55        }
56        ExpressionKind::Comparison(left, _op, right) => {
57            collect_references(left, fact_refs, rule_refs);
58            collect_references(right, fact_refs, rule_refs);
59        }
60        ExpressionKind::LogicalAnd(left, right) => {
61            collect_references(left, fact_refs, rule_refs);
62            collect_references(right, fact_refs, rule_refs);
63        }
64        ExpressionKind::LogicalOr(left, right) => {
65            collect_references(left, fact_refs, rule_refs);
66            collect_references(right, fact_refs, rule_refs);
67        }
68        ExpressionKind::LogicalNegation(inner, _negation_type) => {
69            collect_references(inner, fact_refs, rule_refs);
70        }
71        ExpressionKind::UnitConversion(value, _target) => {
72            collect_references(value, fact_refs, rule_refs);
73        }
74        ExpressionKind::MathematicalOperator(_op, operand) => {
75            collect_references(operand, fact_refs, rule_refs);
76        }
77        ExpressionKind::FactHasAnyValue(fact_ref) => {
78            fact_refs.insert(fact_ref.reference.clone());
79        }
80        ExpressionKind::Veto(_) | ExpressionKind::Literal(_) => {}
81    }
82}
83
84/// Build a dependency graph showing which rules depend on which other rules.
85///
86/// Returns a map: rule_name -> set of rule names it depends on.
87/// This graph is used for topological sorting to determine execution order
88/// and for detecting circular dependencies.
89///
90/// # Examples
91/// ```text
92/// Given rules:
93///   rule total = subtotal?
94///   rule subtotal = price * quantity
95///
96/// Returns:
97///   {"total": {"subtotal"}, "subtotal": {}}
98/// ```
99pub fn build_dependency_graph(rules: &[LemmaRule]) -> HashMap<String, HashSet<String>> {
100    let mut graph = HashMap::new();
101
102    for rule in rules {
103        let mut dependencies = HashSet::new();
104
105        // Extract rule references from the main expression
106        extract_rule_references(&rule.expression, &mut dependencies);
107
108        // Extract rule references from unless clauses
109        for unless_clause in &rule.unless_clauses {
110            extract_rule_references(&unless_clause.condition, &mut dependencies);
111            extract_rule_references(&unless_clause.result, &mut dependencies);
112        }
113
114        graph.insert(rule.name.clone(), dependencies);
115    }
116
117    graph
118}
119
120/// Extract only rule references from an expression
121fn extract_rule_references(expr: &Expression, references: &mut HashSet<String>) {
122    match &expr.kind {
123        ExpressionKind::RuleReference(rule_ref) => {
124            let rule_name = if rule_ref.reference.len() > 1 {
125                rule_ref.reference.join(".")
126            } else {
127                rule_ref.reference.last().unwrap_or(&String::new()).clone()
128            };
129            references.insert(rule_name);
130        }
131        ExpressionKind::LogicalAnd(left, right)
132        | ExpressionKind::LogicalOr(left, right)
133        | ExpressionKind::Arithmetic(left, _, right)
134        | ExpressionKind::Comparison(left, _, right) => {
135            extract_rule_references(left, references);
136            extract_rule_references(right, references);
137        }
138        ExpressionKind::UnitConversion(inner, _)
139        | ExpressionKind::LogicalNegation(inner, _)
140        | ExpressionKind::MathematicalOperator(_, inner) => {
141            extract_rule_references(inner, references);
142        }
143        ExpressionKind::Veto(_)
144        | ExpressionKind::FactHasAnyValue(_)
145        | ExpressionKind::FactReference(_)
146        | ExpressionKind::Literal(_) => {}
147    }
148}
149
150/// Find all missing facts and rules for a rule.
151///
152/// Returns (missing_facts, missing_rules) where:
153/// - missing_facts: Facts that have type annotations (not provided)
154/// - missing_rules: Rules that this rule depends on that couldn't be evaluated
155///
156/// Used to provide helpful error messages about what inputs are needed
157/// to successfully evaluate a rule.
158///
159/// # Examples
160/// ```text
161/// Given:
162///   fact price: number
163///   rule total = price * 2
164///
165/// Returns: (["price [number]"], [])
166/// ```
167pub fn find_missing_dependencies(
168    rule: &LemmaRule,
169    document_facts: &[LemmaFact],
170    evaluated_results: &[RuleResult],
171) -> (Vec<String>, Vec<String>) {
172    let refs = extract_references(&rule.expression);
173
174    // Also collect from unless clauses
175    let mut all_fact_refs = refs.facts;
176    let mut all_rule_refs = refs.rules;
177
178    for unless_clause in &rule.unless_clauses {
179        let unless_refs = extract_references(&unless_clause.condition);
180        all_fact_refs.extend(unless_refs.facts);
181        all_rule_refs.extend(unless_refs.rules);
182
183        let result_refs = extract_references(&unless_clause.result);
184        all_fact_refs.extend(result_refs.facts);
185        all_rule_refs.extend(result_refs.rules);
186    }
187
188    // Find missing facts (have type annotations)
189    let mut missing_facts = Vec::new();
190    for fact_ref in all_fact_refs {
191        let fact_name = fact_ref.join(".");
192
193        if let Some(fact) = document_facts
194            .iter()
195            .find(|f| fact_display_name(f) == fact_name)
196        {
197            if let FactValue::TypeAnnotation(type_ann) = &fact.value {
198                let formatted = format!("{} [{}]", fact_name, format_type_annotation(type_ann));
199                missing_facts.push(formatted);
200            }
201        }
202    }
203
204    // Find missing rules (couldn't be evaluated or have missing facts)
205    let mut missing_rules = Vec::new();
206    for rule_ref in all_rule_refs {
207        let rule_name = rule_ref.join(".");
208
209        // Check if this rule was evaluated successfully
210        if let Some(result) = evaluated_results.iter().find(|r| r.rule_name == rule_name) {
211            // If it has no result or has missing_facts, it couldn't be evaluated
212            if result.result.is_none() {
213                missing_rules.push(rule_name);
214            }
215        }
216    }
217
218    missing_facts.sort();
219    missing_rules.sort();
220
221    (missing_facts, missing_rules)
222}
223
224/// Recursively find all facts required by a rule, following rule dependencies.
225///
226/// This traces through the entire dependency tree:
227/// - If rule A depends on rule B which depends on fact X, this returns fact X for rule A
228/// - Handles cycles gracefully by tracking visited rules
229/// - Only returns facts with type annotations (facts without values)
230///
231/// Returns a set of fact names that are required (directly or transitively).
232///
233/// # Examples
234/// ```text
235/// Given:
236///   fact quantity: number
237///   rule subtotal = price * quantity
238///   rule total = subtotal? + shipping
239///
240/// For rule "total":
241///   Returns: {"quantity"} (price and shipping must have values)
242/// ```
243pub fn find_required_facts_recursive(
244    rule: &LemmaRule,
245    all_rules: &[LemmaRule],
246    document_facts: &[LemmaFact],
247) -> HashSet<String> {
248    let mut required_facts = HashSet::new();
249    let mut visited_rules = HashSet::new();
250
251    collect_required_facts_recursive(
252        rule,
253        all_rules,
254        document_facts,
255        &mut required_facts,
256        &mut visited_rules,
257    );
258
259    required_facts
260}
261
262/// Helper function to recursively collect required facts
263fn collect_required_facts_recursive(
264    rule: &LemmaRule,
265    all_rules: &[LemmaRule],
266    document_facts: &[LemmaFact],
267    required_facts: &mut HashSet<String>,
268    visited_rules: &mut HashSet<String>,
269) {
270    // Prevent infinite recursion from circular dependencies
271    if visited_rules.contains(&rule.name) {
272        return;
273    }
274    visited_rules.insert(rule.name.clone());
275
276    // Extract direct fact and rule references
277    let refs = extract_references(&rule.expression);
278    let mut all_fact_refs = refs.facts;
279    let mut all_rule_refs = refs.rules;
280
281    // Collect from unless clauses
282    for unless_clause in &rule.unless_clauses {
283        let cond_refs = extract_references(&unless_clause.condition);
284        all_fact_refs.extend(cond_refs.facts);
285        all_rule_refs.extend(cond_refs.rules);
286
287        let res_refs = extract_references(&unless_clause.result);
288        all_fact_refs.extend(res_refs.facts);
289        all_rule_refs.extend(res_refs.rules);
290    }
291
292    // Add direct fact references (only those with type annotations - requiring values)
293    for fact_ref in all_fact_refs {
294        let fact_name = fact_ref.join(".");
295        if let Some(fact) = document_facts
296            .iter()
297            .find(|f| fact_display_name(f) == fact_name)
298        {
299            if matches!(fact.value, FactValue::TypeAnnotation(_)) {
300                required_facts.insert(fact_name);
301            }
302        }
303    }
304
305    // Recursively process rule dependencies
306    for rule_ref in all_rule_refs {
307        let rule_name = rule_ref.join(".");
308        if let Some(dep_rule) = all_rules.iter().find(|r| r.name == rule_name) {
309            collect_required_facts_recursive(
310                dep_rule,
311                all_rules,
312                document_facts,
313                required_facts,
314                visited_rules,
315            );
316        }
317    }
318}
319
320/// Format a type annotation for display
321fn format_type_annotation(type_ann: &crate::TypeAnnotation) -> String {
322    use crate::{LemmaType, TypeAnnotation};
323    match type_ann {
324        TypeAnnotation::LemmaType(lemma_type) => match lemma_type {
325            LemmaType::Boolean => "boolean".to_string(),
326            LemmaType::Number => "number".to_string(),
327            LemmaType::Money => "money".to_string(),
328            LemmaType::Text => "text".to_string(),
329            LemmaType::Date => "date".to_string(),
330            LemmaType::Duration => "duration".to_string(),
331            LemmaType::Percentage => "percentage".to_string(),
332            LemmaType::Mass => "mass".to_string(),
333            LemmaType::Length => "length".to_string(),
334            LemmaType::Volume => "volume".to_string(),
335            LemmaType::Data => "datasize".to_string(),
336            LemmaType::Energy => "energy".to_string(),
337            LemmaType::Power => "power".to_string(),
338            LemmaType::Pressure => "pressure".to_string(),
339            LemmaType::Temperature => "temperature".to_string(),
340            LemmaType::Force => "force".to_string(),
341            LemmaType::Frequency => "frequency".to_string(),
342            LemmaType::Regex => "regex".to_string(),
343        },
344    }
345}
346
347/// Get a display name for a fact
348///
349/// Local facts use their name directly.
350/// Foreign facts join their reference path with dots.
351pub fn fact_display_name(fact: &LemmaFact) -> String {
352    match &fact.fact_type {
353        FactType::Local(name) => name.clone(),
354        FactType::Foreign(foreign_ref) => foreign_ref.reference.join("."),
355    }
356}