1use crate::{
9 Expression, ExpressionKind, FactReference, FactType, FactValue, LemmaDoc, LemmaFact,
10 LemmaResult, LemmaRule, RulePath,
11};
12use std::collections::{HashMap, HashSet};
13
14#[derive(Debug, Clone, Default)]
16pub struct References {
17 pub facts: HashSet<FactReference>,
19 pub rules: HashSet<Vec<String>>,
22}
23
24pub fn extract_references(expr: &Expression) -> References {
38 let mut refs = References::default();
39 collect_references(expr, &mut refs.facts, &mut refs.rules);
40 refs
41}
42
43fn collect_references(
45 expr: &Expression,
46 fact_refs: &mut HashSet<FactReference>,
47 rule_refs: &mut HashSet<Vec<String>>,
48) {
49 match &expr.kind {
50 ExpressionKind::FactReference(fact_ref) => {
51 fact_refs.insert(fact_ref.clone());
52 }
53 ExpressionKind::RuleReference(rule_ref) => {
54 rule_refs.insert(rule_ref.reference.clone());
55 }
56 ExpressionKind::Arithmetic(left, _op, right) => {
57 collect_references(left, fact_refs, rule_refs);
58 collect_references(right, fact_refs, rule_refs);
59 }
60 ExpressionKind::Comparison(left, _op, right) => {
61 collect_references(left, fact_refs, rule_refs);
62 collect_references(right, fact_refs, rule_refs);
63 }
64 ExpressionKind::LogicalAnd(left, right) => {
65 collect_references(left, fact_refs, rule_refs);
66 collect_references(right, fact_refs, rule_refs);
67 }
68 ExpressionKind::LogicalOr(left, right) => {
69 collect_references(left, fact_refs, rule_refs);
70 collect_references(right, fact_refs, rule_refs);
71 }
72 ExpressionKind::LogicalNegation(inner, _negation_type) => {
73 collect_references(inner, fact_refs, rule_refs);
74 }
75 ExpressionKind::UnitConversion(value, _target) => {
76 collect_references(value, fact_refs, rule_refs);
77 }
78 ExpressionKind::MathematicalOperator(_op, operand) => {
79 collect_references(operand, fact_refs, rule_refs);
80 }
81 ExpressionKind::FactHasAnyValue(fact_ref) => {
82 fact_refs.insert(fact_ref.clone());
83 }
84 ExpressionKind::Veto(_) | ExpressionKind::Literal(_) => {}
85 }
86}
87
88pub fn find_required_facts_recursive(
108 rule: &LemmaRule,
109 all_rules: &[LemmaRule],
110 document_facts: &[LemmaFact],
111) -> HashSet<String> {
112 let mut required_facts = HashSet::new();
113 let mut visited_rules = HashSet::new();
114
115 collect_required_facts_recursive(
116 rule,
117 all_rules,
118 document_facts,
119 &mut required_facts,
120 &mut visited_rules,
121 );
122
123 required_facts
124}
125
126fn collect_required_facts_recursive(
128 rule: &LemmaRule,
129 all_rules: &[LemmaRule],
130 document_facts: &[LemmaFact],
131 required_facts: &mut HashSet<String>,
132 visited_rules: &mut HashSet<String>,
133) {
134 if visited_rules.contains(&rule.name) {
136 return;
137 }
138 visited_rules.insert(rule.name.clone());
139
140 let refs = extract_references(&rule.expression);
142 let mut all_fact_refs = refs.facts;
143 let mut all_rule_refs = refs.rules;
144
145 for unless_clause in &rule.unless_clauses {
147 let cond_refs = extract_references(&unless_clause.condition);
148 all_fact_refs.extend(cond_refs.facts);
149 all_rule_refs.extend(cond_refs.rules);
150
151 let res_refs = extract_references(&unless_clause.result);
152 all_fact_refs.extend(res_refs.facts);
153 all_rule_refs.extend(res_refs.rules);
154 }
155
156 for fact_ref in all_fact_refs {
158 let fact_name = fact_ref.reference.join(".");
159 if let Some(fact) = document_facts
160 .iter()
161 .find(|f| fact_display_name(f) == fact_name)
162 {
163 if matches!(fact.value, FactValue::TypeAnnotation(_)) {
164 required_facts.insert(fact_name);
165 }
166 }
167 }
168
169 for rule_ref in all_rule_refs {
171 let rule_name = rule_ref.join(".");
172 if let Some(dep_rule) = all_rules.iter().find(|r| r.name == rule_name) {
173 collect_required_facts_recursive(
174 dep_rule,
175 all_rules,
176 document_facts,
177 required_facts,
178 visited_rules,
179 );
180 }
181 }
182}
183
184pub fn fact_display_name(fact: &LemmaFact) -> String {
189 match &fact.fact_type {
190 FactType::Local(name) => name.clone(),
191 FactType::Foreign(foreign_ref) => foreign_ref.reference.join("."),
192 }
193}
194
195fn extract_rule_paths(
201 expr: &Expression,
202 current_doc: &LemmaDoc,
203 all_documents: &HashMap<String, LemmaDoc>,
204 paths: &mut HashSet<RulePath>,
205) -> LemmaResult<()> {
206 match &expr.kind {
207 ExpressionKind::RuleReference(rule_ref) => {
208 let path = RulePath::from_reference(&rule_ref.reference, current_doc, all_documents)?;
209 paths.insert(path);
210 }
211 ExpressionKind::LogicalAnd(left, right)
212 | ExpressionKind::LogicalOr(left, right)
213 | ExpressionKind::Arithmetic(left, _, right)
214 | ExpressionKind::Comparison(left, _, right) => {
215 extract_rule_paths(left, current_doc, all_documents, paths)?;
216 extract_rule_paths(right, current_doc, all_documents, paths)?;
217 }
218 ExpressionKind::UnitConversion(inner, _)
219 | ExpressionKind::LogicalNegation(inner, _)
220 | ExpressionKind::MathematicalOperator(_, inner) => {
221 extract_rule_paths(inner, current_doc, all_documents, paths)?;
222 }
223 _ => {}
224 }
225 Ok(())
226}
227
228pub fn build_dependency_graph(
235 doc: &LemmaDoc,
236 documents: &HashMap<String, LemmaDoc>,
237) -> LemmaResult<HashMap<RulePath, HashSet<RulePath>>> {
238 use std::collections::VecDeque;
239
240 let mut graph = HashMap::new();
241 let mut queue = VecDeque::new();
242
243 for rule in &doc.rules {
245 let path = RulePath {
246 rule: rule.name.clone(),
247 segments: vec![],
248 };
249 queue.push_back((path, rule, doc));
250 }
251
252 while let Some((path, rule, rule_doc)) = queue.pop_front() {
254 if graph.contains_key(&path) {
256 continue;
257 }
258
259 let mut dependencies = HashSet::new();
261 extract_rule_paths(&rule.expression, rule_doc, documents, &mut dependencies)?;
262 for uc in &rule.unless_clauses {
263 extract_rule_paths(&uc.condition, rule_doc, documents, &mut dependencies)?;
264 extract_rule_paths(&uc.result, rule_doc, documents, &mut dependencies)?;
265 }
266
267 graph.insert(path.clone(), dependencies.clone());
269
270 for dep_path in dependencies {
272 if !graph.contains_key(&dep_path) {
273 let target_doc_name = dep_path.target_doc(&doc.name);
274 let target_doc = documents.get(target_doc_name).ok_or_else(|| {
275 crate::LemmaError::Engine(format!(
276 "Rule {} references document '{}' which does not exist",
277 path, target_doc_name
278 ))
279 })?;
280
281 let target_rule = target_doc
282 .rules
283 .iter()
284 .find(|r| r.name == dep_path.rule)
285 .ok_or_else(|| {
286 crate::LemmaError::Engine(format!(
287 "Rule {} references rule '{}' in document '{}' which does not exist",
288 path, dep_path.rule, target_doc_name
289 ))
290 })?;
291
292 queue.push_back((dep_path, target_rule, target_doc));
293 }
294 }
295 }
296
297 Ok(graph)
298}