Skip to main content

xlog_logic/hypergraph/
scc.rs

1//! Multi-predicate SCC fixpoint evaluator.
2//!
3//! Extends the single-target [`super::evaluate_fixpoint`] to a
4//! mutually-recursive SCC: rules grouped by their target predicate,
5//! evaluated jointly so each predicate's body can reference any
6//! other predicate in the SCC. Naive evaluation: one iteration runs
7//! every rule for every predicate; convergence is the global
8//! "no relation grew this iteration" check.
9//!
10//! Pure-Rust, deterministic, set-semantics. Built to be the
11//! recursive-multi-predicate WCOJ correctness oracle for PR 5+
12//! mixed-execution kernels. Not optimized — semi-naive
13//! delta-driven SCC fixpoint is a separate concern.
14//!
15//! ## Determinism
16//!
17//! * Rules grouped under each predicate are evaluated in input
18//!   order (slice order). Rule order does NOT affect the result
19//!   (locked by test) — it's an efficiency knob, not a semantic
20//!   one.
21//! * Predicates are iterated in [`BTreeMap`]'s sorted-by-key
22//!   order. Predicate order does NOT affect the result (locked
23//!   by test).
24//! * Output rows for each predicate are sorted lexicographically
25//!   and deduplicated.
26//!
27//! ## Schema management
28//!
29//! Per-predicate schemas are frozen from the first iteration
30//! that produces non-empty rows for that predicate. Subsequent
31//! iterations validate every newly-produced row's variants against
32//! the frozen schema; mismatches surface as
33//! [`SccFixpointError::InconsistentHeadValueTypes`] before the
34//! row is unioned. Row arity drift surfaces as
35//! [`SccFixpointError::HeadArityMismatch`] at function entry —
36//! checked once per predicate group across its rules.
37
38use super::{
39    evaluate_rule, FixpointConfig, RefEvalError, RefRelation, RefRelationStore, RefValue,
40    VariableOrder,
41};
42use crate::ast::Rule;
43use std::collections::BTreeMap;
44use xlog_core::ScalarType;
45
/// Errors surfaced by [`evaluate_scc_fixpoint`].
///
/// Variants fall into three groups: entry-validation failures
/// (`InvalidMaxIterations`, `PredicateInBaseRelations`,
/// `RuleHeadPredicateMismatch`, `SchemaIndeterminable`,
/// `HeadArityMismatch`), per-iteration failures (`RuleEval`,
/// `InconsistentHeadValueTypes`), and non-convergence
/// (`MaxIterationsExceeded`).
#[derive(Debug, Clone, PartialEq)]
pub enum SccFixpointError {
    /// A rule grouped under predicate `group_key` heads a
    /// different predicate. The grouping invariant — every rule's
    /// head predicate equals its `BTreeMap` key — is checked at
    /// function entry.
    RuleHeadPredicateMismatch {
        /// `BTreeMap` key under which the rule was grouped.
        group_key: String,
        /// Index of the rule within that group.
        rule_index: usize,
        /// Head predicate observed on the rule.
        observed: String,
    },
    /// Two rules grouped under the same predicate disagree on
    /// head arity. Rules with an empty head are skipped by this
    /// check.
    HeadArityMismatch {
        /// Predicate name.
        predicate: String,
        /// Index of the offending rule within its group.
        rule_index: usize,
        /// Head arity observed on this rule.
        observed_arity: usize,
        /// Head arity established by the first non-empty-head
        /// rule in the same group.
        expected_arity: usize,
    },
    /// A predicate's rules produced rows whose [`RefValue`]
    /// variants disagree across iterations or across rules
    /// within an iteration. Detected by validating each newly
    /// produced row's variant tuple against the predicate's
    /// frozen schema before unioning.
    InconsistentHeadValueTypes {
        /// Predicate name.
        predicate: String,
        /// Column index where the mismatch was first observed.
        column: usize,
        /// Schema-frozen scalar type at that column.
        expected: ScalarType,
        /// `Debug` rendering of the offending value.
        got: String,
    },
    /// One of the SCC predicates (a key of `rules`) was already
    /// present in `base_relations`. SCC predicates are
    /// constructed by the fixpoint; allowing `base_relations` to
    /// seed any of them would silently shadow the caller's seed.
    PredicateInBaseRelations {
        /// The SCC predicate name as supplied.
        name: String,
    },
    /// A rule failed evaluation. Wraps the per-rule error with
    /// (predicate, rule_index) so the caller can pinpoint which
    /// rule of which group failed.
    RuleEval {
        /// Predicate group of the offending rule.
        predicate: String,
        /// Index within that group.
        rule_index: usize,
        /// The wrapped per-rule error.
        source: RefEvalError,
    },
    /// The SCC fixpoint did not converge within
    /// [`FixpointConfig::max_iterations`].
    MaxIterationsExceeded {
        /// The configured cap.
        limit: usize,
        /// Number of predicates in the SCC.
        predicate_count: usize,
        /// Total derived rows summed across all SCC predicates
        /// at the cap.
        total_observed_rows: usize,
    },
    /// A predicate had no rules with a non-empty head (this
    /// includes an empty rule group), so its arity could not be
    /// inferred.
    SchemaIndeterminable {
        /// Predicate name whose arity could not be inferred.
        predicate: String,
    },
    /// `max_iterations` was zero. Must be ≥ 1.
    InvalidMaxIterations,
}
128
129/// Evaluate a mutually-recursive SCC of predicates to a fixpoint.
130///
131/// `rules` maps each predicate name to the list of rules deriving
132/// it. Every rule's head predicate must equal its group key
133/// (validated at entry). `base_relations` carries non-SCC
134/// predicates referenced in rule bodies (e.g. EDB facts). The
135/// SCC predicates must NOT appear in `base_relations`.
136///
137/// Returns a [`RefRelationStore`] whose keys are exactly the
138/// keys of `rules`, each mapped to the converged relation. Set
139/// semantics: rows sorted lexicographically, deduplicated.
140#[allow(clippy::result_large_err)]
141pub fn evaluate_scc_fixpoint(
142    rules: &BTreeMap<String, Vec<Rule>>,
143    base_relations: &RefRelationStore,
144    order: &dyn VariableOrder,
145    config: &FixpointConfig,
146) -> Result<RefRelationStore, SccFixpointError> {
147    if config.max_iterations == 0 {
148        return Err(SccFixpointError::InvalidMaxIterations);
149    }
150
151    // Entry validation: per-predicate group invariants.
152    let mut arities: BTreeMap<String, usize> = BTreeMap::new();
153    for (predicate, group) in rules.iter() {
154        if base_relations.contains_key(predicate) {
155            return Err(SccFixpointError::PredicateInBaseRelations {
156                name: predicate.clone(),
157            });
158        }
159        for (idx, rule) in group.iter().enumerate() {
160            if rule.head.predicate != *predicate {
161                return Err(SccFixpointError::RuleHeadPredicateMismatch {
162                    group_key: predicate.clone(),
163                    rule_index: idx,
164                    observed: rule.head.predicate.clone(),
165                });
166            }
167        }
168        // Establish arity from the first non-empty-head rule.
169        let arity = group
170            .iter()
171            .find(|r| !r.head.terms.is_empty())
172            .map(|r| r.head.terms.len())
173            .ok_or_else(|| SccFixpointError::SchemaIndeterminable {
174                predicate: predicate.clone(),
175            })?;
176        for (idx, rule) in group.iter().enumerate() {
177            if rule.head.terms.is_empty() {
178                continue;
179            }
180            if rule.head.terms.len() != arity {
181                return Err(SccFixpointError::HeadArityMismatch {
182                    predicate: predicate.clone(),
183                    rule_index: idx,
184                    observed_arity: rule.head.terms.len(),
185                    expected_arity: arity,
186                });
187            }
188        }
189        arities.insert(predicate.clone(), arity);
190    }
191
192    // Per-predicate frozen schema and derived rows. Schemas seed
193    // as `[U32; arity]` (matches PR 3); first non-empty iter
194    // freezes from row variants.
195    let mut frozen_schemas: BTreeMap<String, Option<Vec<ScalarType>>> = BTreeMap::new();
196    let mut derived: BTreeMap<String, Vec<Vec<RefValue>>> = BTreeMap::new();
197    for predicate in rules.keys() {
198        frozen_schemas.insert(predicate.clone(), None);
199        derived.insert(predicate.clone(), Vec::new());
200    }
201
202    for _iter in 0..config.max_iterations {
203        // Build the per-iter store: base ∪ {predicate → derived}.
204        let mut store = base_relations.clone();
205        for (predicate, rows) in derived.iter() {
206            let schema = frozen_schemas
207                .get(predicate)
208                .and_then(|s| s.clone())
209                .unwrap_or_else(|| vec![ScalarType::U32; arities[predicate]]);
210            store.insert(
211                predicate.clone(),
212                RefRelation {
213                    schema,
214                    rows: rows.clone(),
215                },
216            );
217        }
218
219        // For every predicate in sorted-key order, run every rule
220        // and union new tuples into a per-predicate scratch buffer.
221        let mut next: BTreeMap<String, Vec<Vec<RefValue>>> = derived.clone();
222        for (predicate, group) in rules.iter() {
223            let mut produced: Vec<Vec<RefValue>> = Vec::new();
224            for (rule_index, rule) in group.iter().enumerate() {
225                let rows =
226                    evaluate_rule(rule, &store, order).map_err(|e| SccFixpointError::RuleEval {
227                        predicate: predicate.clone(),
228                        rule_index,
229                        source: e,
230                    })?;
231                produced.extend(rows);
232            }
233            // Freeze the schema from the first iteration that
234            // produces non-empty rows for THIS predicate; once
235            // frozen, validate every newly-produced row's
236            // variant tuple.
237            let frozen_entry = frozen_schemas.get_mut(predicate).expect("inserted above");
238            if frozen_entry.is_none() {
239                if let Some(first) = produced.first() {
240                    *frozen_entry = Some(infer_schema(first));
241                }
242            }
243            if let Some(schema) = frozen_entry.as_ref() {
244                for row in &produced {
245                    if let Some((column, expected, got)) = first_type_mismatch(row, schema) {
246                        return Err(SccFixpointError::InconsistentHeadValueTypes {
247                            predicate: predicate.clone(),
248                            column,
249                            expected,
250                            got,
251                        });
252                    }
253                }
254            }
255            let target = next.get_mut(predicate).expect("predicate present");
256            target.extend(produced);
257            target.sort();
258            target.dedup();
259        }
260
261        if next == derived {
262            // Converged. Build the final RefRelationStore.
263            let mut out: RefRelationStore = BTreeMap::new();
264            for (predicate, rows) in derived.into_iter() {
265                let schema = frozen_schemas
266                    .get(&predicate)
267                    .and_then(|s| s.clone())
268                    .unwrap_or_else(|| vec![ScalarType::U32; arities[&predicate]]);
269                out.insert(predicate, RefRelation { schema, rows });
270            }
271            return Ok(out);
272        }
273        derived = next;
274    }
275
276    let total: usize = derived.values().map(|v| v.len()).sum();
277    Err(SccFixpointError::MaxIterationsExceeded {
278        limit: config.max_iterations,
279        predicate_count: rules.len(),
280        total_observed_rows: total,
281    })
282}
283
284fn infer_schema(row: &[RefValue]) -> Vec<ScalarType> {
285    row.iter()
286        .map(|v| match v {
287            RefValue::U32(_) => ScalarType::U32,
288            RefValue::U64(_) => ScalarType::U64,
289            RefValue::I32(_) => ScalarType::I32,
290            RefValue::I64(_) => ScalarType::I64,
291            RefValue::Bool(_) => ScalarType::Bool,
292            RefValue::Symbol(_) => ScalarType::Symbol,
293        })
294        .collect()
295}
296
297/// Return the first column where `row`'s [`RefValue`] variant does
298/// not match `schema`'s [`ScalarType`]. Mirrors PR 2's row-level
299/// `ref_value_matches_scalar_type`, applied at row produced (not
300/// row stored) time so caller-facing errors point at rule output,
301/// not at downstream relation validation.
302///
303/// Row arity is enforced upstream by [`SccFixpointError::HeadArityMismatch`]
304/// at function entry; if a row of mismatched length somehow escapes
305/// that check, it surfaces downstream as
306/// [`RefEvalError::RelationRowArityMismatch`] from PR 2's validation
307/// — which is honest about what it found. We do NOT add a synthetic
308/// arity-mismatch arm here with a placeholder `ScalarType` (that
309/// pattern was the silent-skip bug class PR 2's validator made us
310/// fix in `RefEvalError::ConstantTypeMismatch`).
311fn first_type_mismatch(
312    row: &[RefValue],
313    schema: &[ScalarType],
314) -> Option<(usize, ScalarType, String)> {
315    for (i, (val, ty)) in row.iter().zip(schema.iter()).enumerate() {
316        let ok = matches!(
317            (val, ty),
318            (RefValue::U32(_), ScalarType::U32)
319                | (RefValue::U64(_), ScalarType::U64)
320                | (RefValue::I32(_), ScalarType::I32)
321                | (RefValue::I64(_), ScalarType::I64)
322                | (RefValue::Bool(_), ScalarType::Bool)
323                | (RefValue::Symbol(_), ScalarType::Symbol)
324        );
325        if !ok {
326            return Some((i, *ty, format!("{val:?}")));
327        }
328    }
329    None
330}