Skip to main content

mangle_analysis/
stratification.rs

1// Copyright 2025 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use crate::PredicateSet;
16use fxhash::{FxHashMap, FxHashSet};
17use mangle_ast as ast;
18use mangle_ast::Arena;
19use std::fmt;
20
21/// Represents a Mangle program consisting of logic rules and declarations.
22///
23/// `Program` wraps the AST and provides the necessary methods to identify
24/// predicates and their dependencies. It is the primary input for the stratification algorithm.
25///
26/// It distinguishes between *extensional* predicates (stored facts) and *intensional*
27/// predicates (derived rules).
28#[derive(Clone)]
29pub struct Program<'p> {
30    pub arena: &'p Arena,
31    pub ext_preds: Vec<ast::PredicateIndex>,
32    pub rules: FxHashMap<ast::PredicateIndex, Vec<&'p ast::Clause<'p>>>,
33}
34
35impl<'p> fmt::Debug for Program<'p> {
36    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
37        f.debug_struct("Program")
38            .field("ext_preds", &self.ext_preds)
39            .field("rules", &self.rules)
40            .finish()
41    }
42}
43
44/// A program that has been successfully stratified.
45///
46/// Contains the original program and the computed strata.
47/// This structure is used to guide the execution order of the IR.
48///
49/// Stratification ensures that if a predicate `p` depends negatively on `q`,
50/// then `q` is evaluated in an earlier stratum than `p`. This allows for
51/// the correct evaluation of negation and aggregation (semi-naive evaluation).
52#[derive(Clone)]
53pub struct StratifiedProgram<'p> {
54    program: Program<'p>,
55    strata: Vec<PredicateSet>,
56}
57
58impl<'p> fmt::Debug for StratifiedProgram<'p> {
59    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
60        f.debug_struct("StratifiedProgram")
61            .field("program", &self.program)
62            .field("strata", &self.strata)
63            .finish()
64    }
65}
66
67type EdgeMap = FxHashMap<ast::PredicateIndex, bool>;
68type DepGraph = FxHashMap<ast::PredicateIndex, EdgeMap>;
69type Nodeset = FxHashSet<ast::PredicateIndex>;
70
71impl<'p> Program<'p> {
72    pub fn new(arena: &'p Arena) -> Self {
73        Self {
74            arena,
75            ext_preds: Vec::new(),
76            rules: FxHashMap::default(),
77        }
78    }
79
80    pub fn add_clause<'src>(&mut self, src: &'src Arena, clause: &'src ast::Clause) {
81        let clause = self.arena.copy_clause(src, clause);
82        let sym = clause.head.sym;
83        use std::collections::hash_map::Entry;
84        match self.rules.entry(sym) {
85            Entry::Occupied(mut v) => v.get_mut().push(clause),
86            Entry::Vacant(v) => {
87                v.insert(vec![clause]);
88            }
89        }
90    }
91
92    /// Returns the AST Arena containing the program data.
93    pub fn arena(&'p self) -> &'p ast::Arena {
94        self.arena
95    }
96
97    /// Returns predicates for extensional DB (stored facts).
98    pub fn extensional_preds(&'p self) -> PredicateSet {
99        let mut set = FxHashSet::default();
100        set.extend(&self.ext_preds);
101        set
102    }
103
104    /// Returns predicates for intensional DB (derived rules).
105    pub fn intensional_preds(&'p self) -> PredicateSet {
106        let mut set = FxHashSet::default();
107        set.extend(self.rules.keys());
108        set
109    }
110
111    /// Maps predicates of intensional DB to their defining rules.
112    pub fn rules(&'p self, sym: ast::PredicateIndex) -> impl Iterator<Item = &'p ast::Clause<'p>> {
113        self.rules.get(&sym).unwrap().iter().copied()
114    }
115
116    /// Analyzes the program's dependency graph and attempts to stratify it.
117    ///
118    /// Stratification partitions the predicates into ordered layers (strata).
119    /// This is essential for evaluating programs with negation or aggregation,
120    /// ensuring that dependencies are fully evaluated before they are used.
121    ///
122    /// Returns a `StratifiedProgram` on success, or an error if the program
123    /// contains unstratifiable cycles (e.g., negation cycles).
124    pub fn stratify(self) -> Result<StratifiedProgram<'p>, String> {
125        let dep = make_dep_graph(&self);
126        let mut strata = dep.sccs();
127
128        let mut pred_to_stratum: FxHashMap<ast::PredicateIndex, usize> = FxHashMap::default();
129
130        for (i, c) in strata.iter().enumerate() {
131            for sym in c {
132                pred_to_stratum.insert(*sym, i);
133            }
134            for sym in c {
135                if let Some(edges) = dep.get(sym) {
136                    for (dest, negated) in edges {
137                        if !*negated {
138                            continue;
139                        }
140                        let dest_stratum = pred_to_stratum.get(dest);
141                        if let Some(dest_stratum) = dest_stratum
142                            && *dest_stratum == i
143                        {
144                            return Err("program cannot be stratified".to_string());
145                        }
146                    }
147                }
148            }
149        }
150        dep.sort_result(&mut strata, pred_to_stratum);
151        let stratified = StratifiedProgram {
152            program: self,
153            strata: strata.into_iter().collect(),
154        };
155        Ok(stratified)
156    }
157}
158
159impl<'p> StratifiedProgram<'p> {
160    /// Returns the AST Arena containing the program data.
161    pub fn arena(&'p self) -> &'p ast::Arena {
162        self.program.arena()
163    }
164
165    /// Returns predicates for extensional DB (stored facts).
166    pub fn extensional_preds(&'p self) -> PredicateSet {
167        self.program.extensional_preds()
168    }
169
170    /// Returns predicates for intensional DB (derived rules).
171    pub fn intensional_preds(&'p self) -> PredicateSet {
172        self.program.intensional_preds()
173    }
174
175    /// Maps predicates of intensional DB to their defining rules.
176    pub fn rules(&'p self, sym: ast::PredicateIndex) -> impl Iterator<Item = &'p ast::Clause<'p>> {
177        self.program.rules(sym)
178    }
179
180    /// Returns an iterator of strata, in dependency order.
181    /// Each stratum is a set of mutually recursive predicates that can be evaluated together.
182    pub fn strata(&'p self) -> Vec<PredicateSet> {
183        self.strata.to_vec()
184    }
185
186    /// Returns the stratum index for a given predicate symbol.
187    /// Returns `None` if the predicate is not part of the stratified program (e.g. it's EDB).
188    pub fn pred_to_index(&'p self, sym: ast::PredicateIndex) -> Option<usize> {
189        self.strata.iter().position(|x| x.contains(&sym))
190    }
191}
192
193fn make_dep_graph<'p>(program: &Program<'p>) -> DepGraph {
194    let mut dep: DepGraph = FxHashMap::default();
195
196    for (s, rule) in program.rules.iter() {
197        dep.init_node(*s);
198        for clause in rule.iter() {
199            for premise in clause.premises.iter() {
200                match premise {
201                    ast::Term::Atom(atom_pred) => {
202                        if !program.extensional_preds().contains(&atom_pred.sym) {
203                            if clause.transform.is_empty() || clause.transform[0].var.is_some() {
204                                dep.add_edge(*s, atom_pred.sym, false);
205                            } else {
206                                dep.add_edge(*s, atom_pred.sym, true);
207                            }
208                        }
209                    }
210                    ast::Term::NegAtom(atom_pred) => {
211                        if !program.extensional_preds().contains(&atom_pred.sym) {
212                            dep.add_edge(*s, atom_pred.sym, true);
213                        }
214                    }
215                    _ => {}
216                }
217            }
218        }
219    }
220    dep
221}
222
/// Rearranges `arr` in place so that the element originally at index `i` ends up at
/// index `permutation[i]`.
///
/// Works cycle by cycle: the first element of a cycle is lifted out (leaving
/// `T::default()` behind temporarily), every following slot of the cycle receives the
/// value carried from the previous slot, and the last carried value closes the cycle.
///
/// `permutation` is expected to be a permutation of `0..arr.len()`.
fn apply_permutation_cycle_rotate<T: Default>(arr: &mut [T], permutation: &[usize]) {
    let mut placed = vec![false; arr.len()];
    for start in 0..arr.len() {
        if placed[start] {
            continue;
        }
        placed[start] = true;
        // Fixed point: the element is already where it belongs.
        if permutation[start] == start {
            continue;
        }
        let mut carried = std::mem::take(&mut arr[start]);
        let mut slot = permutation[start];
        while slot != start {
            placed[slot] = true;
            // Drop the carried value into its slot and pick up the one it displaces.
            carried = std::mem::replace(&mut arr[slot], carried);
            slot = permutation[slot];
        }
        // Back at the start of the cycle: the last displaced value belongs here.
        arr[start] = carried;
    }
}
250
251trait DepGraphExt {
252    fn init_node(&mut self, src: ast::PredicateIndex);
253    fn add_edge(&mut self, src: ast::PredicateIndex, dest: ast::PredicateIndex, negated: bool);
254    fn transpose(&self) -> DepGraph;
255    fn sccs(&self) -> Vec<Nodeset>;
256    fn sort_result(
257        &self,
258        strata: &mut Vec<Nodeset>,
259        pred_to_stratum_map: FxHashMap<ast::PredicateIndex, usize>,
260    ) -> FxHashMap<ast::PredicateIndex, usize>;
261}
262
263impl DepGraphExt for DepGraph {
264    fn init_node(&mut self, src: ast::PredicateIndex) {
265        self.entry(src).or_default();
266    }
267
268    fn add_edge(&mut self, src: ast::PredicateIndex, dest: ast::PredicateIndex, negated: bool) {
269        let edges = self.entry(src).or_default();
270        if negated {
271            edges.insert(dest, negated);
272            return;
273        }
274        if edges.get(&dest).is_none() || !edges[&dest] {
275            edges.insert(dest, false);
276        }
277    }
278
279    fn transpose(&self) -> DepGraph {
280        let mut rev: DepGraph = FxHashMap::default();
281        for (src, edges) in self.iter() {
282            for (dest, negated) in edges.iter() {
283                rev.init_node(*dest);
284                rev.add_edge(*dest, *src, *negated);
285            }
286        }
287        rev
288    }
289
290    fn sccs(&self) -> Vec<Nodeset> {
291        let mut s: Vec<ast::PredicateIndex> = Vec::new();
292        let mut seen: Nodeset = FxHashSet::default();
293
294        fn visit(
295            node: ast::PredicateIndex,
296            graph: &DepGraph,
297            s: &mut Vec<ast::PredicateIndex>,
298            seen: &mut Nodeset,
299        ) {
300            if !seen.contains(&node) {
301                seen.insert(node);
302                if let Some(edges) = graph.get(&node) {
303                    for &neighbor in edges.keys() {
304                        visit(neighbor, graph, s, seen);
305                    }
306                }
307                s.push(node);
308            }
309        }
310
311        for (node, _) in self.iter() {
312            visit(*node, self, &mut s, &mut seen);
313        }
314
315        let rev = self.transpose();
316        let mut seen: Nodeset = FxHashSet::default();
317        fn rvisit(
318            node: ast::PredicateIndex,
319            rev: &DepGraph,
320            scc: &mut Nodeset,
321            seen: &mut Nodeset,
322        ) {
323            if !seen.contains(&node) {
324                seen.insert(node);
325                scc.insert(node);
326                if let Some(edges) = rev.get(&node) {
327                    for &e in edges.keys() {
328                        rvisit(e, rev, scc, seen);
329                    }
330                }
331            }
332        }
333        let mut sccs: Vec<Nodeset> = Vec::new();
334        while let Some(top) = s.pop() {
335            if !seen.contains(&top) {
336                let mut scc: Nodeset = FxHashSet::default();
337                rvisit(top, &rev, &mut scc, &mut seen);
338                if !scc.is_empty() {
339                    sccs.push(scc);
340                }
341            }
342        }
343        sccs
344    }
345
346    fn sort_result(
347        &self,
348        strata: &mut Vec<Nodeset>,
349        pred_to_stratum_map: FxHashMap<ast::PredicateIndex, usize>,
350    ) -> FxHashMap<ast::PredicateIndex, usize> {
351        let mut sorted_indices: Vec<usize> = Vec::new();
352        let mut seen: FxHashSet<usize> = FxHashSet::default();
353        let num_strata = strata.len();
354
355        fn visit_stratum(
356            index: usize,
357            dep: &DepGraph,
358            strata: &Vec<Nodeset>,
359            pred_to_stratum_map: &FxHashMap<ast::PredicateIndex, usize>,
360            seen: &mut FxHashSet<usize>,
361            sorted_indices: &mut Vec<usize>,
362        ) {
363            if seen.contains(&index) {
364                return;
365            }
366            seen.insert(index);
367
368            if let Some(scc) = strata.get(index) {
369                for sym in scc {
370                    if let Some(edges) = dep.get(sym) {
371                        for d in edges.keys() {
372                            if let Some(&dep_stratum_index) = pred_to_stratum_map.get(d) {
373                                visit_stratum(
374                                    dep_stratum_index,
375                                    dep,
376                                    strata,
377                                    pred_to_stratum_map,
378                                    seen,
379                                    sorted_indices,
380                                );
381                            }
382                        }
383                    }
384                }
385            }
386            sorted_indices.push(index);
387        }
388
389        for i in 0..num_strata {
390            visit_stratum(
391                i,
392                self,
393                strata,
394                &pred_to_stratum_map,
395                &mut seen,
396                &mut sorted_indices,
397            );
398        }
399
400        let mut permutation = vec![0; num_strata];
401        let mut old_to_new_map: FxHashMap<usize, usize> = FxHashMap::default();
402        for new_idx in 0..num_strata {
403            let old_idx = sorted_indices[new_idx];
404            permutation[old_idx] = new_idx;
405            old_to_new_map.insert(old_idx, new_idx);
406        }
407
408        apply_permutation_cycle_rotate(strata, &permutation);
409
410        let mut new_pred_to_stratum_map: FxHashMap<ast::PredicateIndex, usize> =
411            FxHashMap::default();
412        for (sym, &old_idx) in pred_to_stratum_map.iter() {
413            if let Some(&new_idx) = old_to_new_map.get(&old_idx) {
414                new_pred_to_stratum_map.insert(*sym, new_idx);
415            }
416        }
417        new_pred_to_stratum_map
418    }
419}
420
#[cfg(test)]
mod tests {
    use super::*;
    use mangle_parse::Parser;

    /// A program whose only negation (`!s` in the rule for `r`) does not take part in a
    /// cycle must stratify, and the strata must respect the dependencies.
    #[test]
    fn test_stratification_success() {
        let arena = Arena::new_with_global_interner();
        let source = r#"
            p(1).
            q(X) :- p(X).
            r(X) :- q(X), !s(X).
            s(2).
        "#;
        let mut parser = Parser::new(&arena, source.as_bytes(), "test");
        parser.next_token().unwrap();
        let unit = parser.parse_unit().unwrap();

        let mut program = Program::new(&arena);
        for clause in unit.clauses {
            program.add_clause(&arena, clause);
        }

        let stratified = program.stratify().expect("should be stratifiable");

        // Resolves a predicate name to the index of its stratum, if any.
        let stratum_of = |name: &str| -> Option<usize> {
            let interned = arena.lookup_opt(name)?;
            let pred = arena.lookup_predicate_sym(interned)?;
            stratified.pred_to_index(pred)
        };

        let [s, r, q, p] = ["s", "r", "q", "p"].map(stratum_of);

        // Every predicate of the program must have been assigned a stratum.
        assert!(s.is_some());
        assert!(r.is_some());
        assert!(q.is_some());
        assert!(p.is_some());
        let (s, r, q, p) = (s.unwrap(), r.unwrap(), q.unwrap(), p.unwrap());

        // Negative dependency r -> s: strictly later stratum.
        assert!(r > s, "r should be higher than s");

        // Positive dependency q -> p: same or later stratum.
        assert!(q >= p, "q should be >= p");

        // Positive dependency r -> q: same or later stratum.
        assert!(r >= q, "r should be >= q");
    }

    /// A predicate defined through its own negation cannot be stratified.
    #[test]
    fn test_stratification_cycle() {
        let arena = Arena::new_with_global_interner();
        let source = "p(X) :- !p(X).";
        let mut parser = Parser::new(&arena, source.as_bytes(), "test");
        parser.next_token().unwrap();
        let unit = parser.parse_unit().unwrap();

        let mut program = Program::new(&arena);
        for clause in unit.clauses {
            program.add_clause(&arena, clause);
        }

        let outcome = program.stratify();
        assert!(outcome.is_err(), "should detect negation cycle");
    }
}