// xlog_logic/hypergraph/inference.rs

//! Transitive type inference across SCC predicates.
//!
//! Closes the PR 5 policy gap: where a join-key vertex was anchored
//! only through SCC-recursive atoms, the typed gate previously left
//! it untyped under "unknown ≠ unsupported." This module propagates
//! types through the rule graph — body atoms type variables, head
//! atoms back-propagate to head-predicate columns, iterate to
//! fixpoint — so the typed gate has full type information when it
//! consults [`super::analyze_typed`].
//!
//! ## Where inference is engaged
//!
//! Only the **group-aware** typed entry points engage inference:
//!
//! * [`super::evaluate_scc_fixpoint_typed`] runs inference once at
//!   entry, then types each rule's body using the inferred schemas
//!   plus `base_relations`.
//! * [`super::evaluate_fixpoint_typed`] treats `target_predicate`
//!   as a single-element rule group and runs the same inference.
//!
//! The single-rule entry points retain the base-only typing policy
//! because they have no SCC structure to propagate over:
//!
//! * [`super::evaluate_rule_typed`] takes one rule.
//! * [`super::plan_rule`] / [`super::plan_rules`] plan per-rule.
//!
//! Callers that want SCC-aware planning should drive
//! [`super::evaluate_scc_fixpoint_typed`] directly or build their
//! own inference pass via [`infer_scc_predicate_schemas`].
//!
//! ## Conflict layering
//!
//! Inference detects only **back-propagation conflicts**: e.g.,
//! predicate `p`'s column 0 is `U32` from rule A's head and
//! `Symbol` from rule B's head → [`InferenceError::ConflictingPredicateColumnType`].
//! Within-rule body conflicts (variable `X` typed `U32` in one body
//! atom and `Symbol` in another) stay in the existing
//! [`super::typed`] flow and surface as
//! [`super::RefEvalError::ConflictingVariableType`]. Each conflict
//! type is detected at exactly one layer.
//!
//! ## Cyclic-only predicates
//!
//! When an SCC has no base anchor anywhere (e.g., `a(X) :- b(X),
//! b(X) :- a(X)` with no rule referencing `base_relations`), every
//! column converges to `None`. The typed gate must NOT reject such
//! rules: the policy narrows from "unknown ≠ unsupported" to
//! "unknowable-after-inference ≠ unsupported." Locked by
//! `cyclic_only_predicate_still_passes_typed_gate_locked_policy`.
//!
//! ## Strict-correctness behavior change
//!
//! Fixtures whose base-relation schemas disagreed but whose actual
//! rows happened to agree at runtime were previously silent (the
//! typed gate types each body atom independently). They now surface
//! as [`InferenceError::ConflictingPredicateColumnType`] when
//! back-propagating to a head predicate. That is a strict
//! correctness win, not a regression — fixtures with internally
//! contradictory schemas are now caught before evaluation rather
//! than silently corrupting downstream comparisons.
61
62use super::reference::RefRelationStore;
63use crate::ast::{BodyLiteral, Rule, Term};
64use std::collections::BTreeMap;
65use xlog_core::ScalarType;
66
67/// Errors surfaced by [`infer_scc_predicate_schemas`].
68#[derive(Debug, Clone, PartialEq)]
69pub enum InferenceError {
70    /// Two rules contributing to the same head predicate disagree
71    /// on the type of the same column. The first rule that types
72    /// the column wins `first_*`; the rule that disagrees wins
73    /// `second_*`.
74    ConflictingPredicateColumnType {
75        /// Head predicate name where the conflict was detected.
76        predicate: String,
77        /// 0-based column index where types disagree.
78        column: usize,
79        /// Rule index (within the predicate's rule group) that
80        /// first typed the column.
81        first_rule_index: usize,
82        /// Type derived from the first rule's body for the head
83        /// variable at this column.
84        first_type: ScalarType,
85        /// Rule index (within the predicate's rule group) whose
86        /// derivation conflicts.
87        second_rule_index: usize,
88        /// Type derived from the conflicting rule's body for the
89        /// head variable at this column.
90        second_type: ScalarType,
91    },
92}
93
94/// Per-predicate inferred schema. `Vec` length equals the head
95/// arity; each element is `Some(t)` if inference established the
96/// column's type, or `None` if the column remains unknowable
97/// (e.g., cyclic-only predicate, or a head term whose body atoms
98/// don't type the corresponding variable).
99pub type InferredSchemas = BTreeMap<String, Vec<Option<ScalarType>>>;
100
101/// Infer per-predicate schemas for a rule group via constraint
102/// propagation through the rule graph.
103///
104/// Algorithm:
105///
106/// 1. Determine head arity per predicate from the first rule with
107///    a non-empty head. (Predicates whose every rule has an empty
108///    head are treated as 0-arity; in practice this is rare.)
109/// 2. Initialize each predicate's schema as `vec![None; arity]`.
110/// 3. Iterate: for each rule, compute a per-rule variable-to-type
111///    map by walking body atoms (typing vars from
112///    `base_relations` schemas first, then from currently-inferred
113///    SCC predicate schemas where columns are `Some`). Then
114///    back-propagate: for each `Term::Variable` in the head at
115///    column `i`, if the variable has a derived type, propose it
116///    as the type for `head_predicate.schema[i]`. Conflict if a
117///    column has been previously typed differently.
118/// 4. Stop when no schema column changes between iterations.
119///
120/// Within-rule body conflicts are NOT detected here; they are
121/// caught by the existing [`super::typed`] gate during its own
122/// per-rule type-derivation walk. See module docs for the
123/// conflict-layering split.
124pub fn infer_scc_predicate_schemas(
125    rules: &BTreeMap<String, Vec<Rule>>,
126    base_relations: &RefRelationStore,
127) -> Result<InferredSchemas, InferenceError> {
128    // Step 1+2: arity + initial schemas.
129    let mut schemas: InferredSchemas = BTreeMap::new();
130    for (predicate, group) in rules.iter() {
131        let arity = group
132            .iter()
133            .find(|r| !r.head.terms.is_empty())
134            .map(|r| r.head.terms.len())
135            .unwrap_or(0);
136        schemas.insert(predicate.clone(), vec![None; arity]);
137    }
138    // Track the rule index that first typed each column so the
139    // conflict report can name both contributors.
140    let mut origins: BTreeMap<(String, usize), usize> = BTreeMap::new();
141    // Inference is monotonic: every iteration that changes
142    // anything replaces a `None` with a `Some(_)`. The total
143    // number of column slots across all SCC predicates is the
144    // strict upper bound on iterations that produce change. We
145    // add 1 to allow for the final no-change iteration that
146    // detects convergence.
147    let total_columns: usize = schemas.values().map(|s| s.len()).sum();
148    let max_iterations = total_columns + 1;
149    let mut converged = false;
150    for _ in 0..max_iterations {
151        let mut changed = false;
152        for (predicate, group) in rules.iter() {
153            for (rule_index, rule) in group.iter().enumerate() {
154                let var_types = derive_rule_var_types(rule, base_relations, &schemas);
155                // Back-propagate from head terms to head-predicate
156                // columns.
157                for (col, term) in rule.head.terms.iter().enumerate() {
158                    let name = match term {
159                        Term::Variable(n) => n,
160                        // Head constants / aggregates / wildcards do
161                        // not constrain a column type via inference.
162                        // Their type would be locked by the value
163                        // itself at evaluation time.
164                        _ => continue,
165                    };
166                    let Some(&derived) = var_types.get(name) else {
167                        continue;
168                    };
169                    let schema = schemas
170                        .get_mut(predicate)
171                        .expect("predicate in initialized schemas");
172                    if col >= schema.len() {
173                        // Head arity drift across rules — let the
174                        // structural SCC fixpoint surface this as
175                        // HeadArityMismatch. Inference doesn't
176                        // pre-empt; just skip this column.
177                        continue;
178                    }
179                    match schema[col] {
180                        None => {
181                            schema[col] = Some(derived);
182                            origins.insert((predicate.clone(), col), rule_index);
183                            changed = true;
184                        }
185                        Some(existing) if existing == derived => {
186                            // Agreement — silent.
187                        }
188                        Some(existing) => {
189                            let first_rule_index =
190                                origins.get(&(predicate.clone(), col)).copied().unwrap_or(0);
191                            return Err(InferenceError::ConflictingPredicateColumnType {
192                                predicate: predicate.clone(),
193                                column: col,
194                                first_rule_index,
195                                first_type: existing,
196                                second_rule_index: rule_index,
197                                second_type: derived,
198                            });
199                        }
200                    }
201                }
202            }
203        }
204        if !changed {
205            converged = true;
206            break;
207        }
208    }
209    // Monotonic invariant: every iteration that changed something
210    // replaced a None with a Some(_). The bound `total_columns + 1`
211    // strictly exceeds the number of such iterations possible, so
212    // failing to converge here indicates a future code change has
213    // broken the monotonicity guarantee — a programmer error, not
214    // a data error.
215    debug_assert!(
216        converged,
217        "type inference failed to converge within {max_iterations} iterations \
218         (monotonicity invariant violated)"
219    );
220    Ok(schemas)
221}
222
223/// Derive the per-variable type map for a single rule, consulting
224/// both `base_relations` and currently-inferred SCC schemas.
225///
226/// Body conflicts (a variable typed two different ways across
227/// body atoms within this rule) are NOT surfaced here — that is
228/// the responsibility of [`super::typed::derive_vertex_types`],
229/// which the typed gate calls before evaluation. This helper is
230/// a *forward* propagation pass that prefers the first type seen
231/// (in source order) and silently skips later disagreements; the
232/// typed gate later catches the disagreement on the same rule
233/// using its own walk.
234fn derive_rule_var_types(
235    rule: &Rule,
236    base_relations: &RefRelationStore,
237    inferred: &InferredSchemas,
238) -> BTreeMap<String, ScalarType> {
239    let mut var_types: BTreeMap<String, ScalarType> = BTreeMap::new();
240    for literal in &rule.body {
241        let body_atom = match literal {
242            BodyLiteral::Positive(a) => a,
243            _ => continue,
244        };
245        let schema_opt: Option<&[Option<ScalarType>]> =
246            if let Some(rel) = base_relations.get(&body_atom.predicate) {
247                // Build a transient "all-Some" view of the base schema.
248                // We don't actually need to allocate — handle directly.
249                let limit = body_atom.terms.len().min(rel.schema.len());
250                for (pos, term) in body_atom.terms[..limit].iter().enumerate() {
251                    if let Term::Variable(name) = term {
252                        var_types.entry(name.clone()).or_insert(rel.schema[pos]);
253                    }
254                }
255                None
256            } else {
257                inferred.get(&body_atom.predicate).map(|v| v.as_slice())
258            };
259        if let Some(schema) = schema_opt {
260            let limit = body_atom.terms.len().min(schema.len());
261            for (pos, term) in body_atom.terms[..limit].iter().enumerate() {
262                if let Term::Variable(name) = term {
263                    if let Some(ty) = schema[pos] {
264                        var_types.entry(name.clone()).or_insert(ty);
265                    }
266                }
267            }
268        }
269    }
270    var_types
271}
272
273/// Build the typed-gate input map for a single rule using
274/// inferred SCC schemas alongside base relations.
275///
276/// Mirrors [`super::typed::derive_vertex_types`]'s contract — same
277/// conflict surface ([`super::RefEvalError::ConflictingVariableType`])
278/// — but consults `inferred_schemas` whenever a body atom's
279/// predicate is not in `base_relations`. Inferred columns marked
280/// `None` are treated identically to "predicate absent": they
281/// don't type the variable at that position.
282///
283/// Used by [`super::evaluate_scc_fixpoint_typed`] and
284/// [`super::evaluate_fixpoint_typed`] inside their per-rule typed
285/// gate to give [`super::analyze_typed`] full type information.
286pub(super) fn derive_vertex_types_with_inference(
287    rule: &Rule,
288    base_relations: &RefRelationStore,
289    inferred_schemas: &InferredSchemas,
290) -> Result<BTreeMap<String, ScalarType>, super::RefEvalError> {
291    /// First-recorded site for a variable; used to populate the
292    /// `ConflictingVariableType` report when a second body atom
293    /// types the variable differently.
294    struct FirstSite {
295        predicate: String,
296        position: usize,
297        ty: ScalarType,
298    }
299    let mut sites: BTreeMap<String, FirstSite> = BTreeMap::new();
300    for literal in &rule.body {
301        let body_atom = match literal {
302            BodyLiteral::Positive(a) => a,
303            _ => continue,
304        };
305        // Type each position. Base relation wins if both are
306        // present (cannot happen — `base_relations` and
307        // `inferred_schemas` keys are disjoint by construction in
308        // the typed evaluators).
309        let position_types: Vec<Option<ScalarType>> =
310            if let Some(rel) = base_relations.get(&body_atom.predicate) {
311                let limit = body_atom.terms.len().min(rel.schema.len());
312                let mut v: Vec<Option<ScalarType>> = vec![None; body_atom.terms.len()];
313                for (pos_idx, slot) in v.iter_mut().enumerate().take(limit) {
314                    *slot = Some(rel.schema[pos_idx]);
315                }
316                v
317            } else if let Some(schema) = inferred_schemas.get(&body_atom.predicate) {
318                let limit = body_atom.terms.len().min(schema.len());
319                let mut v: Vec<Option<ScalarType>> = vec![None; body_atom.terms.len()];
320                for (pos_idx, slot) in v.iter_mut().enumerate().take(limit) {
321                    *slot = schema[pos_idx];
322                }
323                v
324            } else {
325                continue; // predicate unknown, no type info
326            };
327        for (position, term) in body_atom.terms.iter().enumerate() {
328            let var_name = match term {
329                Term::Variable(name) => name.clone(),
330                _ => continue,
331            };
332            let Some(ty) = position_types[position] else {
333                continue;
334            };
335            match sites.get(&var_name) {
336                None => {
337                    sites.insert(
338                        var_name,
339                        FirstSite {
340                            predicate: body_atom.predicate.clone(),
341                            position,
342                            ty,
343                        },
344                    );
345                }
346                Some(prior) if prior.ty == ty => {
347                    // Agreeing repeat — silent.
348                }
349                Some(prior) => {
350                    return Err(super::RefEvalError::ConflictingVariableType {
351                        var: var_name,
352                        first_predicate: prior.predicate.clone(),
353                        first_position: prior.position,
354                        first_type: prior.ty,
355                        second_predicate: body_atom.predicate.clone(),
356                        second_position: position,
357                        second_type: ty,
358                    });
359                }
360            }
361        }
362    }
363    Ok(sites
364        .into_iter()
365        .map(|(name, site)| (name, site.ty))
366        .collect())
367}