Skip to main content

mangle_analysis/
stratification.rs

// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use crate::PredicateSet;
16use fxhash::{FxHashMap, FxHashSet};
17use mangle_ast as ast;
18use mangle_ast::Arena;
19use std::fmt;
20
21/// Represents a Mangle program consisting of logic rules and declarations.
22///
23/// `Program` wraps the AST and provides the necessary methods to identify
24/// predicates and their dependencies. It is the primary input for the stratification algorithm.
25///
26/// It distinguishes between *extensional* predicates (stored facts) and *intensional*
27/// predicates (derived rules).
28#[derive(Clone)]
29pub struct Program<'p> {
30    pub arena: &'p Arena,
31    pub ext_preds: Vec<ast::PredicateIndex>,
32    pub rules: FxHashMap<ast::PredicateIndex, Vec<&'p ast::Clause<'p>>>,
33}
34
35impl<'p> fmt::Debug for Program<'p> {
36    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
37        f.debug_struct("Program")
38            .field("ext_preds", &self.ext_preds)
39            .field("rules", &self.rules)
40            .finish()
41    }
42}
43
44/// A program that has been successfully stratified.
45///
46/// Contains the original program and the computed strata.
47/// This structure is used to guide the execution order of the IR.
48///
49/// Stratification ensures that if a predicate `p` depends negatively on `q`,
50/// then `q` is evaluated in an earlier stratum than `p`. This allows for
51/// the correct evaluation of negation and aggregation (semi-naive evaluation).
52#[derive(Clone)]
53pub struct StratifiedProgram<'p> {
54    program: Program<'p>,
55    strata: Vec<PredicateSet>,
56}
57
58impl<'p> fmt::Debug for StratifiedProgram<'p> {
59    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
60        f.debug_struct("StratifiedProgram")
61            .field("program", &self.program)
62            .field("strata", &self.strata)
63            .finish()
64    }
65}
66
67type EdgeMap = FxHashMap<ast::PredicateIndex, bool>;
68type DepGraph = FxHashMap<ast::PredicateIndex, EdgeMap>;
69type Nodeset = FxHashSet<ast::PredicateIndex>;
70
71impl<'p> Program<'p> {
72    pub fn new(arena: &'p Arena) -> Self {
73        Self {
74            arena,
75            ext_preds: Vec::new(),
76            rules: FxHashMap::default(),
77        }
78    }
79
80    pub fn add_clause<'src>(&mut self, src: &'src Arena, clause: &'src ast::Clause) {
81        let clause = self.arena.copy_clause(src, clause);
82        let sym = clause.head.sym;
83        use std::collections::hash_map::Entry;
84        match self.rules.entry(sym) {
85            Entry::Occupied(mut v) => v.get_mut().push(clause),
86            Entry::Vacant(v) => {
87                v.insert(vec![clause]);
88            }
89        }
90    }
91
92    /// Returns the AST Arena containing the program data.
93    pub fn arena(&'p self) -> &'p ast::Arena {
94        self.arena
95    }
96
97    /// Returns predicates for extensional DB (stored facts).
98    pub fn extensional_preds(&'p self) -> PredicateSet {
99        let mut set = FxHashSet::default();
100        set.extend(&self.ext_preds);
101        set
102    }
103
104    /// Returns predicates for intensional DB (derived rules).
105    pub fn intensional_preds(&'p self) -> PredicateSet {
106        let mut set = FxHashSet::default();
107        set.extend(self.rules.keys());
108        set
109    }
110
111    /// Maps predicates of intensional DB to their defining rules.
112    pub fn rules(&'p self, sym: ast::PredicateIndex) -> impl Iterator<Item = &'p ast::Clause<'p>> {
113        self.rules.get(&sym).unwrap().iter().copied()
114    }
115
116    /// Analyzes the program's dependency graph and attempts to stratify it.
117    ///
118    /// Stratification partitions the predicates into ordered layers (strata).
119    /// This is essential for evaluating programs with negation or aggregation,
120    /// ensuring that dependencies are fully evaluated before they are used.
121    ///
122    /// Returns a `StratifiedProgram` on success, or an error if the program
123    /// contains unstratifiable cycles (e.g., negation cycles).
124    pub fn stratify(self) -> Result<StratifiedProgram<'p>, String> {
125        let dep = make_dep_graph(&self);
126        let mut strata = dep.sccs();
127
128        let mut pred_to_stratum: FxHashMap<ast::PredicateIndex, usize> = FxHashMap::default();
129
130        for (i, c) in strata.iter().enumerate() {
131            for sym in c {
132                pred_to_stratum.insert(*sym, i);
133            }
134            for sym in c {
135                if let Some(edges) = dep.get(sym) {
136                    for (dest, negated) in edges {
137                        if !*negated {
138                            continue;
139                        }
140                        let dest_stratum = pred_to_stratum.get(dest);
141                        if let Some(dest_stratum) = dest_stratum
142                            && *dest_stratum == i
143                        {
144                            return Err("program cannot be stratified".to_string());
145                        }
146                    }
147                }
148            }
149        }
150        dep.sort_result(&mut strata, pred_to_stratum);
151        let stratified = StratifiedProgram {
152            program: self,
153            strata: strata.into_iter().collect(),
154        };
155        Ok(stratified)
156    }
157}
158
159impl<'p> StratifiedProgram<'p> {
160    /// Returns the AST Arena containing the program data.
161    pub fn arena(&'p self) -> &'p ast::Arena {
162        self.program.arena()
163    }
164
165    /// Returns predicates for extensional DB (stored facts).
166    pub fn extensional_preds(&'p self) -> PredicateSet {
167        self.program.extensional_preds()
168    }
169
170    /// Returns predicates for intensional DB (derived rules).
171    pub fn intensional_preds(&'p self) -> PredicateSet {
172        self.program.intensional_preds()
173    }
174
175    /// Maps predicates of intensional DB to their defining rules.
176    pub fn rules(&'p self, sym: ast::PredicateIndex) -> impl Iterator<Item = &'p ast::Clause<'p>> {
177        self.program.rules(sym)
178    }
179
180    /// Returns an iterator of strata, in dependency order.
181    /// Each stratum is a set of mutually recursive predicates that can be evaluated together.
182    pub fn strata(&'p self) -> Vec<PredicateSet> {
183        self.strata.to_vec()
184    }
185
186    /// Returns the stratum index for a given predicate symbol.
187    /// Returns `None` if the predicate is not part of the stratified program (e.g. it's EDB).
188    pub fn pred_to_index(&'p self, sym: ast::PredicateIndex) -> Option<usize> {
189        self.strata.iter().position(|x| x.contains(&sym))
190    }
191}
192
193fn make_dep_graph<'p>(program: &Program<'p>) -> DepGraph {
194    let mut dep: DepGraph = FxHashMap::default();
195
196    for (s, rule) in program.rules.iter() {
197        dep.init_node(*s);
198        for clause in rule.iter() {
199            for premise in clause.premises.iter() {
200                match premise {
201                    ast::Term::Atom(atom_pred) => {
202                        if !program.extensional_preds().contains(&atom_pred.sym) {
203                            if clause.transform.is_empty() || clause.transform[0].var.is_some() {
204                                dep.add_edge(*s, atom_pred.sym, false);
205                            } else {
206                                dep.add_edge(*s, atom_pred.sym, true);
207                            }
208                        }
209                    }
210                    ast::Term::NegAtom(atom_pred) => {
211                        if !program.extensional_preds().contains(&atom_pred.sym) {
212                            dep.add_edge(*s, atom_pred.sym, true);
213                        }
214                    }
215                    ast::Term::TemporalAtom(atom_pred, _) => {
216                        if !program.extensional_preds().contains(&atom_pred.sym) {
217                            if clause.transform.is_empty() || clause.transform[0].var.is_some() {
218                                dep.add_edge(*s, atom_pred.sym, false);
219                            } else {
220                                dep.add_edge(*s, atom_pred.sym, true);
221                            }
222                        }
223                    }
224                    _ => {}
225                }
226            }
227        }
228    }
229    dep
230}
231
/// Rearranges `arr` in place so that the element at index `i` ends up at
/// index `permutation[i]`.
///
/// Each cycle of the permutation is rotated by carrying one element at a
/// time; the vacated starting slot temporarily holds `T::default()`.
///
/// # Panics
///
/// Panics if the first `arr.len()` entries of `permutation` are not a
/// permutation of `0..arr.len()` (too short, an index out of range, or a
/// repeated target). The check runs before anything is moved, so `arr` is
/// left untouched in that case. Without it, a repeated target would make the
/// cycle walk below loop forever.
fn apply_permutation_cycle_rotate<T: Default>(arr: &mut [T], permutation: &[usize]) {
    let n = arr.len();
    if n == 0 {
        return;
    }

    let mut claimed = vec![false; n];
    for &target in &permutation[..n] {
        assert!(
            target < n && !std::mem::replace(&mut claimed[target], true),
            "invalid permutation: target index {} is out of range or repeated",
            target
        );
    }

    let mut done = vec![false; n];
    for start in 0..n {
        if done[start] {
            continue;
        }
        done[start] = true;

        let mut slot = permutation[start];
        if slot == start {
            // Fixed point: the element is already where it belongs.
            continue;
        }

        // Walk the cycle, dropping the carried element into each slot and
        // picking up the one that was there.
        let mut carried = std::mem::take(&mut arr[start]);
        while slot != start {
            done[slot] = true;
            carried = std::mem::replace(&mut arr[slot], carried);
            slot = permutation[slot];
        }
        // The cycle closes at `start`, whose own element was taken out above.
        arr[start] = carried;
    }
}
259
260trait DepGraphExt {
261    fn init_node(&mut self, src: ast::PredicateIndex);
262    fn add_edge(&mut self, src: ast::PredicateIndex, dest: ast::PredicateIndex, negated: bool);
263    fn transpose(&self) -> DepGraph;
264    fn sccs(&self) -> Vec<Nodeset>;
265    fn sort_result(
266        &self,
267        strata: &mut Vec<Nodeset>,
268        pred_to_stratum_map: FxHashMap<ast::PredicateIndex, usize>,
269    ) -> FxHashMap<ast::PredicateIndex, usize>;
270}
271
272impl DepGraphExt for DepGraph {
273    fn init_node(&mut self, src: ast::PredicateIndex) {
274        self.entry(src).or_default();
275    }
276
277    fn add_edge(&mut self, src: ast::PredicateIndex, dest: ast::PredicateIndex, negated: bool) {
278        let edges = self.entry(src).or_default();
279        if negated {
280            edges.insert(dest, negated);
281            return;
282        }
283        if edges.get(&dest).is_none() || !edges[&dest] {
284            edges.insert(dest, false);
285        }
286    }
287
288    fn transpose(&self) -> DepGraph {
289        let mut rev: DepGraph = FxHashMap::default();
290        for (src, edges) in self.iter() {
291            for (dest, negated) in edges.iter() {
292                rev.init_node(*dest);
293                rev.add_edge(*dest, *src, *negated);
294            }
295        }
296        rev
297    }
298
299    fn sccs(&self) -> Vec<Nodeset> {
300        let mut s: Vec<ast::PredicateIndex> = Vec::new();
301        let mut seen: Nodeset = FxHashSet::default();
302
303        fn visit(
304            node: ast::PredicateIndex,
305            graph: &DepGraph,
306            s: &mut Vec<ast::PredicateIndex>,
307            seen: &mut Nodeset,
308        ) {
309            if !seen.contains(&node) {
310                seen.insert(node);
311                if let Some(edges) = graph.get(&node) {
312                    for &neighbor in edges.keys() {
313                        visit(neighbor, graph, s, seen);
314                    }
315                }
316                s.push(node);
317            }
318        }
319
320        for (node, _) in self.iter() {
321            visit(*node, self, &mut s, &mut seen);
322        }
323
324        let rev = self.transpose();
325        let mut seen: Nodeset = FxHashSet::default();
326        fn rvisit(
327            node: ast::PredicateIndex,
328            rev: &DepGraph,
329            scc: &mut Nodeset,
330            seen: &mut Nodeset,
331        ) {
332            if !seen.contains(&node) {
333                seen.insert(node);
334                scc.insert(node);
335                if let Some(edges) = rev.get(&node) {
336                    for &e in edges.keys() {
337                        rvisit(e, rev, scc, seen);
338                    }
339                }
340            }
341        }
342        let mut sccs: Vec<Nodeset> = Vec::new();
343        while let Some(top) = s.pop() {
344            if !seen.contains(&top) {
345                let mut scc: Nodeset = FxHashSet::default();
346                rvisit(top, &rev, &mut scc, &mut seen);
347                if !scc.is_empty() {
348                    sccs.push(scc);
349                }
350            }
351        }
352        sccs
353    }
354
355    fn sort_result(
356        &self,
357        strata: &mut Vec<Nodeset>,
358        pred_to_stratum_map: FxHashMap<ast::PredicateIndex, usize>,
359    ) -> FxHashMap<ast::PredicateIndex, usize> {
360        let mut sorted_indices: Vec<usize> = Vec::new();
361        let mut seen: FxHashSet<usize> = FxHashSet::default();
362        let num_strata = strata.len();
363
364        fn visit_stratum(
365            index: usize,
366            dep: &DepGraph,
367            strata: &Vec<Nodeset>,
368            pred_to_stratum_map: &FxHashMap<ast::PredicateIndex, usize>,
369            seen: &mut FxHashSet<usize>,
370            sorted_indices: &mut Vec<usize>,
371        ) {
372            if seen.contains(&index) {
373                return;
374            }
375            seen.insert(index);
376
377            if let Some(scc) = strata.get(index) {
378                for sym in scc {
379                    if let Some(edges) = dep.get(sym) {
380                        for d in edges.keys() {
381                            if let Some(&dep_stratum_index) = pred_to_stratum_map.get(d) {
382                                visit_stratum(
383                                    dep_stratum_index,
384                                    dep,
385                                    strata,
386                                    pred_to_stratum_map,
387                                    seen,
388                                    sorted_indices,
389                                );
390                            }
391                        }
392                    }
393                }
394            }
395            sorted_indices.push(index);
396        }
397
398        for i in 0..num_strata {
399            visit_stratum(
400                i,
401                self,
402                strata,
403                &pred_to_stratum_map,
404                &mut seen,
405                &mut sorted_indices,
406            );
407        }
408
409        let mut permutation = vec![0; num_strata];
410        let mut old_to_new_map: FxHashMap<usize, usize> = FxHashMap::default();
411        for new_idx in 0..num_strata {
412            let old_idx = sorted_indices[new_idx];
413            permutation[old_idx] = new_idx;
414            old_to_new_map.insert(old_idx, new_idx);
415        }
416
417        apply_permutation_cycle_rotate(strata, &permutation);
418
419        let mut new_pred_to_stratum_map: FxHashMap<ast::PredicateIndex, usize> =
420            FxHashMap::default();
421        for (sym, &old_idx) in pred_to_stratum_map.iter() {
422            if let Some(&new_idx) = old_to_new_map.get(&old_idx) {
423                new_pred_to_stratum_map.insert(*sym, new_idx);
424            }
425        }
426        new_pred_to_stratum_map
427    }
428}
429
#[cfg(test)]
mod tests {
    use super::*;
    use mangle_parse::Parser;

    /// Parses `source` and loads every clause of the resulting unit into a fresh [`Program`].
    fn load_program<'a>(arena: &'a Arena, source: &'static str) -> Program<'a> {
        let mut parser = Parser::new(arena, source.as_bytes(), "test");
        parser.next_token().unwrap();
        let unit = parser.parse_unit().unwrap();

        let mut program = Program::new(arena);
        for clause in unit.clauses {
            program.add_clause(arena, clause);
        }
        program
    }

    /// A program whose only negation points at a non-recursive predicate
    /// stratifies, and the strata respect the dependency order.
    #[test]
    fn test_stratification_success() {
        let arena = Arena::new_with_global_interner();
        let source = r#"
            p(1).
            q(X) :- p(X).
            r(X) :- q(X), !s(X).
            s(2).
        "#;

        let stratified = load_program(&arena, source)
            .stratify()
            .expect("should be stratifiable");

        // Looks up the stratum index of the predicate called `name`.
        let stratum_of = |name: &str| -> Option<usize> {
            let name_idx = arena.lookup_opt(name)?;
            let pred_idx = arena.lookup_predicate_sym(name_idx)?;
            stratified.pred_to_index(pred_idx)
        };

        let s_idx = stratum_of("s");
        let r_idx = stratum_of("r");
        let q_idx = stratum_of("q");
        let p_idx = stratum_of("p");

        // Every predicate is defined by a clause here, so each has a stratum.
        assert!(s_idx.is_some());
        assert!(r_idx.is_some());
        assert!(q_idx.is_some());
        assert!(p_idx.is_some());

        let (s, r, q, p) = (
            s_idx.unwrap(),
            r_idx.unwrap(),
            q_idx.unwrap(),
            p_idx.unwrap(),
        );

        // r depends negatively on s, so r must come strictly later.
        assert!(r > s, "r should be higher than s");

        // q depends on p.
        assert!(q >= p, "q should be >= p");

        // r depends on q.
        assert!(r >= q, "r should be >= q");
    }

    /// A predicate that negates itself cannot be stratified.
    #[test]
    fn test_stratification_cycle() {
        let arena = Arena::new_with_global_interner();
        let program = load_program(&arena, "p(X) :- !p(X).");

        let res = program.stratify();
        assert!(res.is_err(), "should detect negation cycle");
    }
}