Skip to main content

lemma/planning/
slice_interface.rs

1use crate::parsing::ast::LemmaSpec;
2use crate::planning::execution_plan::ExecutionPlan;
3use crate::planning::semantics::{ExpressionKind, FactData, LemmaType, PathSegment, RulePath};
4use crate::planning::types::ResolvedSpecTypes;
5use crate::Error;
6use std::collections::{BTreeMap, HashMap, HashSet};
7use std::sync::Arc;
8
9type ResolvedTypesMap = HashMap<Arc<LemmaSpec>, ResolvedSpecTypes>;
10
11/// The resolved interface of a referenced spec within a single temporal slice.
12///
13/// Captures only what the caller actually uses: needed facts, referenced rules,
14/// and type definitions. Two SliceInterfaces are equal iff the caller sees the
15/// exact same contract from the referenced spec in both slices.
16#[derive(Debug, Clone, PartialEq, Eq)]
17pub struct SliceInterface {
18    pub facts: BTreeMap<String, FactKind>,
19    pub rules: BTreeMap<String, LemmaType>,
20    pub types: BTreeMap<String, LemmaType>,
21}
22
23#[derive(Debug, Clone, PartialEq, Eq)]
24pub enum FactKind {
25    Value(LemmaType),
26    SpecRef { spec_name: String },
27}
28
29impl SliceInterface {
30    /// Extract the interface of a referenced spec from a built execution plan.
31    ///
32    /// `segments` identifies the referenced spec (e.g. `[PathSegment { fact: "b", spec: "B" }]`).
33    /// Uses the plan's precomputed `needs_facts` to determine which facts matter.
34    pub(crate) fn from_plan(
35        plan: &ExecutionPlan,
36        segments: &[PathSegment],
37        resolved_types: &ResolvedTypesMap,
38        ref_spec: &Arc<LemmaSpec>,
39    ) -> Self {
40        let needed_at_segments = collect_needed_facts_at_segments(plan, segments);
41
42        let mut facts = BTreeMap::new();
43        for (path, data) in &plan.facts {
44            if path.segments != *segments {
45                continue;
46            }
47            if !needed_at_segments.contains(path.fact.as_str()) {
48                continue;
49            }
50            let kind = match data {
51                FactData::Value { value, .. } => FactKind::Value(value.lemma_type.clone()),
52                FactData::TypeDeclaration { resolved_type, .. } => {
53                    FactKind::Value(resolved_type.clone())
54                }
55                FactData::SpecRef { spec, .. } => FactKind::SpecRef {
56                    spec_name: spec.name.clone(),
57                },
58            };
59            facts.insert(path.fact.clone(), kind);
60        }
61
62        let referenced_rules = collect_referenced_rules_at_segments(plan, segments);
63        let mut rules = BTreeMap::new();
64        for rule in &plan.rules {
65            if rule.path.segments != *segments {
66                continue;
67            }
68            if !referenced_rules.contains(rule.name.as_str()) {
69                continue;
70            }
71            rules.insert(rule.name.clone(), rule.rule_type.clone());
72        }
73
74        let mut types = BTreeMap::new();
75        if let Some(spec_types) = resolved_types.get(ref_spec) {
76            for (name, lemma_type) in &spec_types.named_types {
77                types.insert(name.clone(), lemma_type.clone());
78            }
79        }
80
81        SliceInterface {
82            facts,
83            rules,
84            types,
85        }
86    }
87
88    pub fn diff(&self, other: &SliceInterface) -> Vec<String> {
89        let mut diffs = Vec::new();
90        diff_map("fact", &self.facts, &other.facts, &mut diffs, |a, b| a != b);
91        diff_map("rule", &self.rules, &other.rules, &mut diffs, |a, b| a != b);
92        diff_map("type", &self.types, &other.types, &mut diffs, |a, b| a != b);
93        diffs
94    }
95}
96
/// Append to `diffs` a description of how map `b` differs from map `a`.
///
/// Messages are emitted in three passes, each in ascending key order:
/// 1. keys only in `a` ("removed");
/// 2. keys only in `b` ("added");
/// 3. keys in both whose values `changed` reports as different ("changed",
///    showing both values via `Debug`).
///
/// `label` prefixes every message. Existing contents of `diffs` are kept.
fn diff_map<V: std::fmt::Debug>(
    label: &str,
    a: &BTreeMap<String, V>,
    b: &BTreeMap<String, V>,
    diffs: &mut Vec<String>,
    changed: impl Fn(&V, &V) -> bool,
) {
    // Keys that disappeared.
    diffs.extend(
        a.keys()
            .filter(|key| !b.contains_key(*key))
            .map(|key| format!("{} '{}' removed", label, key)),
    );
    // Keys that appeared.
    diffs.extend(
        b.keys()
            .filter(|key| !a.contains_key(*key))
            .map(|key| format!("{} '{}' added", label, key)),
    );
    // Keys present on both sides whose values differ.
    diffs.extend(a.iter().filter_map(|(key, old)| {
        let new = b.get(key)?;
        changed(old, new)
            .then(|| format!("{} '{}' changed: {:?} -> {:?}", label, key, old, new))
    }));
}
125
126/// Collect fact names at `segments` depth that any root-level rule needs.
127///
128/// Uses the plan's precomputed `needs_facts` (transitive closure) and also
129/// extracts intermediate SpecRef traversal facts: if a needed FactPath or
130/// a referenced RulePath passes through a deeper segment, the linking fact at
131/// `segments` depth is itself a needed interface fact.
132fn collect_needed_facts_at_segments<'a>(
133    plan: &'a ExecutionPlan,
134    segments: &[PathSegment],
135) -> HashSet<&'a str> {
136    let mut needed = HashSet::new();
137
138    for rule in &plan.rules {
139        if !rule.path.segments.is_empty() {
140            continue;
141        }
142        for fp in &rule.needs_facts {
143            if fp.segments == *segments {
144                needed.insert(fp.fact.as_str());
145            }
146            if fp.segments.len() > segments.len() && fp.segments[..segments.len()] == *segments {
147                needed.insert(fp.segments[segments.len()].fact.as_str());
148            }
149        }
150    }
151
152    // RulePath references at deeper segments also imply an intermediate
153    // SpecRef fact at our level (e.g. `b.nested.val` means `nested` is needed).
154    let referenced_rules = collect_root_rule_paths(plan);
155    for rp in &referenced_rules {
156        if rp.segments.len() > segments.len() && rp.segments[..segments.len()] == *segments {
157            needed.insert(rp.segments[segments.len()].fact.as_str());
158        }
159    }
160
161    needed
162}
163
164/// Collect rule names at `segments` depth that root-level rules directly reference.
165///
166/// Walks root-level rule expressions for RulePath references at the dep's depth.
167/// Internal dep rules (only reachable transitively within the dep) are excluded —
168/// the caller only cares about the dep rules it explicitly uses.
169fn collect_referenced_rules_at_segments<'a>(
170    plan: &'a ExecutionPlan,
171    segments: &[PathSegment],
172) -> HashSet<&'a str> {
173    let mut referenced = HashSet::new();
174    let all_rule_paths = collect_root_rule_paths(plan);
175    for rp in &all_rule_paths {
176        if rp.segments == *segments {
177            referenced.insert(rp.rule.as_str());
178        }
179    }
180    referenced
181}
182
183/// Collect all RulePath references from root-level rule expressions.
184fn collect_root_rule_paths(plan: &ExecutionPlan) -> Vec<&RulePath> {
185    let mut paths = Vec::new();
186    for rule in &plan.rules {
187        if !rule.path.segments.is_empty() {
188            continue;
189        }
190        for branch in &rule.branches {
191            collect_rule_paths_from_expr(&branch.result, &mut paths);
192            if let Some(cond) = &branch.condition {
193                collect_rule_paths_from_expr(cond, &mut paths);
194            }
195        }
196    }
197    paths
198}
199
200fn collect_rule_paths_from_expr<'a>(
201    expr: &'a crate::planning::semantics::Expression,
202    out: &mut Vec<&'a RulePath>,
203) {
204    match &expr.kind {
205        ExpressionKind::RulePath(rp) => out.push(rp),
206        ExpressionKind::LogicalAnd(l, r)
207        | ExpressionKind::Arithmetic(l, _, r)
208        | ExpressionKind::Comparison(l, _, r) => {
209            collect_rule_paths_from_expr(l, out);
210            collect_rule_paths_from_expr(r, out);
211        }
212        ExpressionKind::UnitConversion(inner, _)
213        | ExpressionKind::LogicalNegation(inner, _)
214        | ExpressionKind::MathematicalComputation(_, inner) => {
215            collect_rule_paths_from_expr(inner, out);
216        }
217        ExpressionKind::DateRelative(_, date_expr, tolerance) => {
218            collect_rule_paths_from_expr(date_expr, out);
219            if let Some(tol) = tolerance {
220                collect_rule_paths_from_expr(tol, out);
221            }
222        }
223        ExpressionKind::DateCalendar(_, _, date_expr) => {
224            collect_rule_paths_from_expr(date_expr, out);
225        }
226        ExpressionKind::Literal(_)
227        | ExpressionKind::FactPath(_)
228        | ExpressionKind::Veto(_)
229        | ExpressionKind::Now => {}
230    }
231}
232
233/// Validate that all temporal slices of a spec see the same interface
234/// from each referenced spec.
235pub(crate) fn validate_slice_interfaces(
236    spec_name: &str,
237    slice_plans: &[ExecutionPlan],
238    resolved_types_per_slice: &[ResolvedTypesMap],
239) -> Vec<Error> {
240    if slice_plans.len() <= 1 {
241        return Vec::new();
242    }
243
244    let ref_segments = collect_ref_spec_segments(&slice_plans[0]);
245
246    let mut errors = Vec::new();
247
248    for (segments, ref_spec_arc) in &ref_segments {
249        let first_interface = SliceInterface::from_plan(
250            &slice_plans[0],
251            segments,
252            &resolved_types_per_slice[0],
253            ref_spec_arc,
254        );
255
256        for (i, plan) in slice_plans.iter().enumerate().skip(1) {
257            let ref_spec_in_slice = find_ref_spec_in_plan(plan, segments);
258            let ref_spec = ref_spec_in_slice.as_ref().unwrap_or(ref_spec_arc);
259            let slice_interface =
260                SliceInterface::from_plan(plan, segments, &resolved_types_per_slice[i], ref_spec);
261
262            if first_interface != slice_interface {
263                let diffs = first_interface.diff(&slice_interface);
264                let diff_detail = if diffs.is_empty() {
265                    String::new()
266                } else {
267                    format!(": {}", diffs.join(", "))
268                };
269                errors.push(Error::validation(
270                    format!(
271                        "Referenced spec '{}' changed its interface between temporal slices of '{}'{}\n\
272                         Create a new temporal version of '{}' to handle the interface change.",
273                        ref_spec_arc.name, spec_name, diff_detail, spec_name
274                    ),
275                    None,
276                    None::<String>,
277                ));
278                break;
279            }
280        }
281    }
282
283    errors
284}
285
286/// Find all first-level referenced spec segments, plus nested ones reachable
287/// through the plan's facts/rules.
288fn collect_ref_spec_segments(plan: &ExecutionPlan) -> Vec<(Vec<PathSegment>, Arc<LemmaSpec>)> {
289    let mut seen = HashSet::new();
290    let mut result = Vec::new();
291
292    for (path, data) in &plan.facts {
293        if let FactData::SpecRef { spec, .. } = data {
294            let mut seg = path.segments.clone();
295            seg.push(PathSegment {
296                fact: path.fact.clone(),
297                spec: spec.name.clone(),
298            });
299            let key = seg
300                .iter()
301                .map(|s| format!("{}.{}", s.fact, s.spec))
302                .collect::<Vec<_>>()
303                .join("/");
304            if seen.insert(key) {
305                result.push((seg, Arc::clone(spec)));
306            }
307        }
308    }
309
310    result
311}
312
313/// Find the Arc<LemmaSpec> for a referenced spec in a plan by matching segments.
314fn find_ref_spec_in_plan(plan: &ExecutionPlan, segments: &[PathSegment]) -> Option<Arc<LemmaSpec>> {
315    if segments.is_empty() {
316        return None;
317    }
318    let parent_segments = &segments[..segments.len() - 1];
319    let target_seg = &segments[segments.len() - 1];
320
321    for (path, data) in &plan.facts {
322        if let FactData::SpecRef { spec, .. } = data {
323            if path.segments == *parent_segments
324                && path.fact == target_seg.fact
325                && spec.name == target_seg.spec
326            {
327                return Some(Arc::clone(spec));
328            }
329        }
330    }
331    None
332}
333
#[cfg(test)]
mod tests {
    use super::*;
    use crate::planning::semantics::primitive_number;

    /// Interface with no facts, rules, or types.
    fn empty_interface() -> SliceInterface {
        SliceInterface {
            facts: BTreeMap::new(),
            rules: BTreeMap::new(),
            types: BTreeMap::new(),
        }
    }

    #[test]
    fn identical_interfaces_are_equal() {
        let mut facts = BTreeMap::new();
        facts.insert("x".to_string(), FactKind::Value(primitive_number().clone()));

        let mut rules = BTreeMap::new();
        rules.insert("z".to_string(), primitive_number().clone());

        let a = SliceInterface {
            facts: facts.clone(),
            rules: rules.clone(),
            types: BTreeMap::new(),
        };
        let b = SliceInterface {
            facts,
            rules,
            types: BTreeMap::new(),
        };
        assert_eq!(a, b);
        assert!(a.diff(&b).is_empty());
    }

    #[test]
    fn added_fact_detected() {
        let a = empty_interface();

        let mut facts_b = BTreeMap::new();
        facts_b.insert("y".to_string(), FactKind::Value(primitive_number().clone()));
        let b = SliceInterface {
            facts: facts_b,
            rules: BTreeMap::new(),
            types: BTreeMap::new(),
        };

        assert_ne!(a, b);
        let diffs = a.diff(&b);
        assert!(diffs.iter().any(|d| d.contains("'y' added")));
    }

    #[test]
    fn removed_rule_detected() {
        let mut rules_a = BTreeMap::new();
        rules_a.insert("z".to_string(), primitive_number().clone());
        let a = SliceInterface {
            facts: BTreeMap::new(),
            rules: rules_a,
            types: BTreeMap::new(),
        };
        let b = empty_interface();

        assert_ne!(a, b);
        let diffs = a.diff(&b);
        assert!(diffs.iter().any(|d| d.contains("'z' removed")));
    }

    /// A fact that keeps its name but switches kind is reported as changed,
    /// not as a removal plus an addition.
    #[test]
    fn changed_fact_kind_detected() {
        let mut a = empty_interface();
        a.facts
            .insert("x".to_string(), FactKind::Value(primitive_number().clone()));
        let mut b = empty_interface();
        b.facts.insert(
            "x".to_string(),
            FactKind::SpecRef {
                spec_name: "B".to_string(),
            },
        );

        assert_ne!(a, b);
        let diffs = a.diff(&b);
        assert_eq!(diffs.len(), 1);
        assert!(diffs[0].starts_with("fact 'x' changed: "));
    }

    /// Fact differences are listed before rule differences.
    #[test]
    fn diff_lists_facts_before_rules() {
        let mut a = empty_interface();
        a.facts
            .insert("x".to_string(), FactKind::Value(primitive_number().clone()));
        let mut b = empty_interface();
        b.rules.insert("z".to_string(), primitive_number().clone());

        assert_eq!(
            a.diff(&b),
            vec!["fact 'x' removed".to_string(), "rule 'z' added".to_string()]
        );
    }
}