1use crate::{Expression, ExpressionKind, FactType, FactValue, LemmaFact, LemmaRule, RuleResult};
9use std::collections::{HashMap, HashSet};
10
/// Fact and rule references collected from an expression tree
/// (see `extract_references`).
///
/// Each reference is kept as its list of path segments, so two references are
/// considered the same only when every segment matches.
#[derive(Default, Clone, Debug)]
pub struct References {
    /// Paths of referenced facts (including `FactHasAnyValue` checks).
    pub facts: HashSet<Vec<String>>,
    /// Paths of referenced rules.
    pub rules: HashSet<Vec<String>>,
}
19
20pub fn extract_references(expr: &Expression) -> References {
34 let mut refs = References::default();
35 collect_references(expr, &mut refs.facts, &mut refs.rules);
36 refs
37}
38
39fn collect_references(
41 expr: &Expression,
42 fact_refs: &mut HashSet<Vec<String>>,
43 rule_refs: &mut HashSet<Vec<String>>,
44) {
45 match &expr.kind {
46 ExpressionKind::FactReference(fact_ref) => {
47 fact_refs.insert(fact_ref.reference.clone());
48 }
49 ExpressionKind::RuleReference(rule_ref) => {
50 rule_refs.insert(rule_ref.reference.clone());
51 }
52 ExpressionKind::Arithmetic(left, _op, right) => {
53 collect_references(left, fact_refs, rule_refs);
54 collect_references(right, fact_refs, rule_refs);
55 }
56 ExpressionKind::Comparison(left, _op, right) => {
57 collect_references(left, fact_refs, rule_refs);
58 collect_references(right, fact_refs, rule_refs);
59 }
60 ExpressionKind::LogicalAnd(left, right) => {
61 collect_references(left, fact_refs, rule_refs);
62 collect_references(right, fact_refs, rule_refs);
63 }
64 ExpressionKind::LogicalOr(left, right) => {
65 collect_references(left, fact_refs, rule_refs);
66 collect_references(right, fact_refs, rule_refs);
67 }
68 ExpressionKind::LogicalNegation(inner, _negation_type) => {
69 collect_references(inner, fact_refs, rule_refs);
70 }
71 ExpressionKind::UnitConversion(value, _target) => {
72 collect_references(value, fact_refs, rule_refs);
73 }
74 ExpressionKind::MathematicalOperator(_op, operand) => {
75 collect_references(operand, fact_refs, rule_refs);
76 }
77 ExpressionKind::FactHasAnyValue(fact_ref) => {
78 fact_refs.insert(fact_ref.reference.clone());
79 }
80 ExpressionKind::Veto(_) | ExpressionKind::Literal(_) => {}
81 }
82}
83
84pub fn build_dependency_graph(rules: &[LemmaRule]) -> HashMap<String, HashSet<String>> {
100 let mut graph = HashMap::new();
101
102 for rule in rules {
103 let mut dependencies = HashSet::new();
104
105 extract_rule_references(&rule.expression, &mut dependencies);
107
108 for unless_clause in &rule.unless_clauses {
110 extract_rule_references(&unless_clause.condition, &mut dependencies);
111 extract_rule_references(&unless_clause.result, &mut dependencies);
112 }
113
114 graph.insert(rule.name.clone(), dependencies);
115 }
116
117 graph
118}
119
120fn extract_rule_references(expr: &Expression, references: &mut HashSet<String>) {
122 match &expr.kind {
123 ExpressionKind::RuleReference(rule_ref) => {
124 let rule_name = if rule_ref.reference.len() > 1 {
125 rule_ref.reference.join(".")
126 } else {
127 rule_ref.reference.last().unwrap_or(&String::new()).clone()
128 };
129 references.insert(rule_name);
130 }
131 ExpressionKind::LogicalAnd(left, right)
132 | ExpressionKind::LogicalOr(left, right)
133 | ExpressionKind::Arithmetic(left, _, right)
134 | ExpressionKind::Comparison(left, _, right) => {
135 extract_rule_references(left, references);
136 extract_rule_references(right, references);
137 }
138 ExpressionKind::UnitConversion(inner, _)
139 | ExpressionKind::LogicalNegation(inner, _)
140 | ExpressionKind::MathematicalOperator(_, inner) => {
141 extract_rule_references(inner, references);
142 }
143 ExpressionKind::Veto(_)
144 | ExpressionKind::FactHasAnyValue(_)
145 | ExpressionKind::FactReference(_)
146 | ExpressionKind::Literal(_) => {}
147 }
148}
149
150pub fn find_missing_dependencies(
168 rule: &LemmaRule,
169 document_facts: &[LemmaFact],
170 evaluated_results: &[RuleResult],
171) -> (Vec<String>, Vec<String>) {
172 let refs = extract_references(&rule.expression);
173
174 let mut all_fact_refs = refs.facts;
176 let mut all_rule_refs = refs.rules;
177
178 for unless_clause in &rule.unless_clauses {
179 let unless_refs = extract_references(&unless_clause.condition);
180 all_fact_refs.extend(unless_refs.facts);
181 all_rule_refs.extend(unless_refs.rules);
182
183 let result_refs = extract_references(&unless_clause.result);
184 all_fact_refs.extend(result_refs.facts);
185 all_rule_refs.extend(result_refs.rules);
186 }
187
188 let mut missing_facts = Vec::new();
190 for fact_ref in all_fact_refs {
191 let fact_name = fact_ref.join(".");
192
193 if let Some(fact) = document_facts
194 .iter()
195 .find(|f| fact_display_name(f) == fact_name)
196 {
197 if let FactValue::TypeAnnotation(type_ann) = &fact.value {
198 let formatted = format!("{} [{}]", fact_name, format_type_annotation(type_ann));
199 missing_facts.push(formatted);
200 }
201 }
202 }
203
204 let mut missing_rules = Vec::new();
206 for rule_ref in all_rule_refs {
207 let rule_name = rule_ref.join(".");
208
209 if let Some(result) = evaluated_results.iter().find(|r| r.rule_name == rule_name) {
211 if result.result.is_none() {
213 missing_rules.push(rule_name);
214 }
215 }
216 }
217
218 missing_facts.sort();
219 missing_rules.sort();
220
221 (missing_facts, missing_rules)
222}
223
224pub fn find_required_facts_recursive(
244 rule: &LemmaRule,
245 all_rules: &[LemmaRule],
246 document_facts: &[LemmaFact],
247) -> HashSet<String> {
248 let mut required_facts = HashSet::new();
249 let mut visited_rules = HashSet::new();
250
251 collect_required_facts_recursive(
252 rule,
253 all_rules,
254 document_facts,
255 &mut required_facts,
256 &mut visited_rules,
257 );
258
259 required_facts
260}
261
262fn collect_required_facts_recursive(
264 rule: &LemmaRule,
265 all_rules: &[LemmaRule],
266 document_facts: &[LemmaFact],
267 required_facts: &mut HashSet<String>,
268 visited_rules: &mut HashSet<String>,
269) {
270 if visited_rules.contains(&rule.name) {
272 return;
273 }
274 visited_rules.insert(rule.name.clone());
275
276 let refs = extract_references(&rule.expression);
278 let mut all_fact_refs = refs.facts;
279 let mut all_rule_refs = refs.rules;
280
281 for unless_clause in &rule.unless_clauses {
283 let cond_refs = extract_references(&unless_clause.condition);
284 all_fact_refs.extend(cond_refs.facts);
285 all_rule_refs.extend(cond_refs.rules);
286
287 let res_refs = extract_references(&unless_clause.result);
288 all_fact_refs.extend(res_refs.facts);
289 all_rule_refs.extend(res_refs.rules);
290 }
291
292 for fact_ref in all_fact_refs {
294 let fact_name = fact_ref.join(".");
295 if let Some(fact) = document_facts
296 .iter()
297 .find(|f| fact_display_name(f) == fact_name)
298 {
299 if matches!(fact.value, FactValue::TypeAnnotation(_)) {
300 required_facts.insert(fact_name);
301 }
302 }
303 }
304
305 for rule_ref in all_rule_refs {
307 let rule_name = rule_ref.join(".");
308 if let Some(dep_rule) = all_rules.iter().find(|r| r.name == rule_name) {
309 collect_required_facts_recursive(
310 dep_rule,
311 all_rules,
312 document_facts,
313 required_facts,
314 visited_rules,
315 );
316 }
317 }
318}
319
320fn format_type_annotation(type_ann: &crate::TypeAnnotation) -> String {
322 use crate::{LemmaType, TypeAnnotation};
323 match type_ann {
324 TypeAnnotation::LemmaType(lemma_type) => match lemma_type {
325 LemmaType::Boolean => "boolean".to_string(),
326 LemmaType::Number => "number".to_string(),
327 LemmaType::Money => "money".to_string(),
328 LemmaType::Text => "text".to_string(),
329 LemmaType::Date => "date".to_string(),
330 LemmaType::Duration => "duration".to_string(),
331 LemmaType::Percentage => "percentage".to_string(),
332 LemmaType::Mass => "mass".to_string(),
333 LemmaType::Length => "length".to_string(),
334 LemmaType::Volume => "volume".to_string(),
335 LemmaType::Data => "datasize".to_string(),
336 LemmaType::Energy => "energy".to_string(),
337 LemmaType::Power => "power".to_string(),
338 LemmaType::Pressure => "pressure".to_string(),
339 LemmaType::Temperature => "temperature".to_string(),
340 LemmaType::Force => "force".to_string(),
341 LemmaType::Frequency => "frequency".to_string(),
342 LemmaType::Regex => "regex".to_string(),
343 },
344 }
345}
346
347pub fn fact_display_name(fact: &LemmaFact) -> String {
352 match &fact.fact_type {
353 FactType::Local(name) => name.clone(),
354 FactType::Foreign(foreign_ref) => foreign_ref.reference.join("."),
355 }
356}