lemma/
analysis.rs

1//! Reference and dependency analysis utilities
2//!
3//! This module provides shared infrastructure for extracting references
4//! from expressions and analyzing dependencies between rules.
5//!
6//! Used by both semantic validation and evaluation.
7
8use crate::{
9    Expression, ExpressionKind, FactReference, FactType, FactValue, LemmaDoc, LemmaFact,
10    LemmaResult, LemmaRule, RulePath,
11};
12use std::collections::{HashMap, HashSet};
13
14/// References extracted from an expression
15#[derive(Debug, Clone, Default)]
16pub struct References {
17    /// Fact references (e.g., FactReference with reference ["employee", "name"])
18    pub facts: HashSet<FactReference>,
19    /// Rule references as raw segments (e.g., ["employee", "is_eligible"])
20    /// Note: These are syntactic references, not yet resolved to RulePath which requires semantic context
21    pub rules: HashSet<Vec<String>>,
22}
23
24/// Extract all fact and rule references from an expression.
25///
26/// Recursively walks the expression tree to find all references to facts and rules.
27/// Useful for dependency analysis and validation.
28///
29/// # Examples
30/// ```text
31/// Expression: price * quantity
32/// Returns: facts = ["price", "quantity"], rules = []
33///
34/// Expression: base_amount + adjustment?
35/// Returns: facts = ["base_amount"], rules = ["adjustment"]
36/// ```
37pub fn extract_references(expr: &Expression) -> References {
38    let mut refs = References::default();
39    collect_references(expr, &mut refs.facts, &mut refs.rules);
40    refs
41}
42
43/// Recursively collect all fact and rule references from an expression
44fn collect_references(
45    expr: &Expression,
46    fact_refs: &mut HashSet<FactReference>,
47    rule_refs: &mut HashSet<Vec<String>>,
48) {
49    match &expr.kind {
50        ExpressionKind::FactReference(fact_ref) => {
51            fact_refs.insert(fact_ref.clone());
52        }
53        ExpressionKind::RuleReference(rule_ref) => {
54            rule_refs.insert(rule_ref.reference.clone());
55        }
56        ExpressionKind::Arithmetic(left, _op, right) => {
57            collect_references(left, fact_refs, rule_refs);
58            collect_references(right, fact_refs, rule_refs);
59        }
60        ExpressionKind::Comparison(left, _op, right) => {
61            collect_references(left, fact_refs, rule_refs);
62            collect_references(right, fact_refs, rule_refs);
63        }
64        ExpressionKind::LogicalAnd(left, right) => {
65            collect_references(left, fact_refs, rule_refs);
66            collect_references(right, fact_refs, rule_refs);
67        }
68        ExpressionKind::LogicalOr(left, right) => {
69            collect_references(left, fact_refs, rule_refs);
70            collect_references(right, fact_refs, rule_refs);
71        }
72        ExpressionKind::LogicalNegation(inner, _negation_type) => {
73            collect_references(inner, fact_refs, rule_refs);
74        }
75        ExpressionKind::UnitConversion(value, _target) => {
76            collect_references(value, fact_refs, rule_refs);
77        }
78        ExpressionKind::MathematicalOperator(_op, operand) => {
79            collect_references(operand, fact_refs, rule_refs);
80        }
81        ExpressionKind::FactHasAnyValue(fact_ref) => {
82            fact_refs.insert(fact_ref.clone());
83        }
84        ExpressionKind::Veto(_) | ExpressionKind::Literal(_) => {}
85    }
86}
87
88/// Recursively find all facts required by a rule, following rule dependencies.
89///
90/// This traces through the entire dependency tree:
91/// - If rule A depends on rule B which depends on fact X, this returns fact X for rule A
92/// - Handles cycles gracefully by tracking visited rules
93/// - Only returns facts with type annotations (facts without values)
94///
95/// Returns a set of fact names that are required (directly or transitively).
96///
97/// # Examples
98/// ```text
99/// Given:
100///   fact quantity: number
101///   rule subtotal = price * quantity
102///   rule total = subtotal? + shipping
103///
104/// For rule "total":
105///   Returns: {"quantity"} (price and shipping must have values)
106/// ```
107pub fn find_required_facts_recursive(
108    rule: &LemmaRule,
109    all_rules: &[LemmaRule],
110    document_facts: &[LemmaFact],
111) -> HashSet<String> {
112    let mut required_facts = HashSet::new();
113    let mut visited_rules = HashSet::new();
114
115    collect_required_facts_recursive(
116        rule,
117        all_rules,
118        document_facts,
119        &mut required_facts,
120        &mut visited_rules,
121    );
122
123    required_facts
124}
125
126/// Helper function to recursively collect required facts
127fn collect_required_facts_recursive(
128    rule: &LemmaRule,
129    all_rules: &[LemmaRule],
130    document_facts: &[LemmaFact],
131    required_facts: &mut HashSet<String>,
132    visited_rules: &mut HashSet<String>,
133) {
134    // Prevent infinite recursion from circular dependencies
135    if visited_rules.contains(&rule.name) {
136        return;
137    }
138    visited_rules.insert(rule.name.clone());
139
140    // Extract direct fact and rule references
141    let refs = extract_references(&rule.expression);
142    let mut all_fact_refs = refs.facts;
143    let mut all_rule_refs = refs.rules;
144
145    // Collect from unless clauses
146    for unless_clause in &rule.unless_clauses {
147        let cond_refs = extract_references(&unless_clause.condition);
148        all_fact_refs.extend(cond_refs.facts);
149        all_rule_refs.extend(cond_refs.rules);
150
151        let res_refs = extract_references(&unless_clause.result);
152        all_fact_refs.extend(res_refs.facts);
153        all_rule_refs.extend(res_refs.rules);
154    }
155
156    // Add direct fact references (only those with type annotations - requiring values)
157    for fact_ref in all_fact_refs {
158        let fact_name = fact_ref.reference.join(".");
159        if let Some(fact) = document_facts
160            .iter()
161            .find(|f| fact_display_name(f) == fact_name)
162        {
163            if matches!(fact.value, FactValue::TypeAnnotation(_)) {
164                required_facts.insert(fact_name);
165            }
166        }
167    }
168
169    // Recursively process rule dependencies
170    for rule_ref in all_rule_refs {
171        let rule_name = rule_ref.join(".");
172        if let Some(dep_rule) = all_rules.iter().find(|r| r.name == rule_name) {
173            collect_required_facts_recursive(
174                dep_rule,
175                all_rules,
176                document_facts,
177                required_facts,
178                visited_rules,
179            );
180        }
181    }
182}
183
184/// Get a display name for a fact
185///
186/// Local facts use their name directly.
187/// Foreign facts join their reference path with dots.
188pub fn fact_display_name(fact: &LemmaFact) -> String {
189    match &fact.fact_type {
190        FactType::Local(name) => name.clone(),
191        FactType::Foreign(foreign_ref) => foreign_ref.reference.join("."),
192    }
193}
194
195/// Extract rule paths from an expression for dependency analysis across document references.
196///
197/// Resolves rule references to `RulePath` instances that include the full
198/// fact traversal path (e.g., `employee.salary?` where `employee` is a fact
199/// referencing another document). Used internally by dependency graph building.
200fn extract_rule_paths(
201    expr: &Expression,
202    current_doc: &LemmaDoc,
203    all_documents: &HashMap<String, LemmaDoc>,
204    paths: &mut HashSet<RulePath>,
205) -> LemmaResult<()> {
206    match &expr.kind {
207        ExpressionKind::RuleReference(rule_ref) => {
208            let path = RulePath::from_reference(&rule_ref.reference, current_doc, all_documents)?;
209            paths.insert(path);
210        }
211        ExpressionKind::LogicalAnd(left, right)
212        | ExpressionKind::LogicalOr(left, right)
213        | ExpressionKind::Arithmetic(left, _, right)
214        | ExpressionKind::Comparison(left, _, right) => {
215            extract_rule_paths(left, current_doc, all_documents, paths)?;
216            extract_rule_paths(right, current_doc, all_documents, paths)?;
217        }
218        ExpressionKind::UnitConversion(inner, _)
219        | ExpressionKind::LogicalNegation(inner, _)
220        | ExpressionKind::MathematicalOperator(_, inner) => {
221            extract_rule_paths(inner, current_doc, all_documents, paths)?;
222        }
223        _ => {}
224    }
225    Ok(())
226}
227
228/// Build dependency graph for all reachable rules
229///
230/// Starting from the document being evaluated, discovers all rules
231/// (local + rules from documents referenced by facts) and extracts dependencies in a single traversal.
232///
233/// Returns: RulePath -> Set of RulePaths it depends on
234pub fn build_dependency_graph(
235    doc: &LemmaDoc,
236    documents: &HashMap<String, LemmaDoc>,
237) -> LemmaResult<HashMap<RulePath, HashSet<RulePath>>> {
238    use std::collections::VecDeque;
239
240    let mut graph = HashMap::new();
241    let mut queue = VecDeque::new();
242
243    // Start with rules from document being evaluated
244    for rule in &doc.rules {
245        let path = RulePath {
246            rule: rule.name.clone(),
247            segments: vec![],
248        };
249        queue.push_back((path, rule, doc));
250    }
251
252    // BFS: discover rules and build dependencies simultaneously
253    while let Some((path, rule, rule_doc)) = queue.pop_front() {
254        // Skip if already processed
255        if graph.contains_key(&path) {
256            continue;
257        }
258
259        // Extract dependencies for this rule (single traversal)
260        let mut dependencies = HashSet::new();
261        extract_rule_paths(&rule.expression, rule_doc, documents, &mut dependencies)?;
262        for uc in &rule.unless_clauses {
263            extract_rule_paths(&uc.condition, rule_doc, documents, &mut dependencies)?;
264            extract_rule_paths(&uc.result, rule_doc, documents, &mut dependencies)?;
265        }
266
267        // Store in graph
268        graph.insert(path.clone(), dependencies.clone());
269
270        // Queue dependencies for discovery
271        for dep_path in dependencies {
272            if !graph.contains_key(&dep_path) {
273                let target_doc_name = dep_path.target_doc(&doc.name);
274                let target_doc = documents.get(target_doc_name).ok_or_else(|| {
275                    crate::LemmaError::Engine(format!(
276                        "Rule {} references document '{}' which does not exist",
277                        path, target_doc_name
278                    ))
279                })?;
280
281                let target_rule = target_doc
282                    .rules
283                    .iter()
284                    .find(|r| r.name == dep_path.rule)
285                    .ok_or_else(|| {
286                        crate::LemmaError::Engine(format!(
287                            "Rule {} references rule '{}' in document '{}' which does not exist",
288                            path, dep_path.rule, target_doc_name
289                        ))
290                    })?;
291
292                queue.push_back((dep_path, target_rule, target_doc));
293            }
294        }
295    }
296
297    Ok(graph)
298}