1use crate::PredicateSet;
16use fxhash::{FxHashMap, FxHashSet};
17use mangle_ast as ast;
18use mangle_ast::Arena;
19use std::fmt;
20
/// A set of clauses grouped by the predicate of their head.
#[derive(Clone)]
pub struct Program<'p> {
    /// Arena into which `add_clause` copies every clause stored in `rules`.
    pub arena: &'p Arena,
    /// Predicates reported by `extensional_preds`; `make_dep_graph` adds no
    /// dependency edge that points at one of these.
    pub ext_preds: Vec<ast::PredicateIndex>,
    /// Clauses keyed by the predicate symbol of their head (`clause.head.sym`).
    pub rules: FxHashMap<ast::PredicateIndex, Vec<&'p ast::Clause<'p>>>,
}
34
35impl<'p> fmt::Debug for Program<'p> {
36 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
37 f.debug_struct("Program")
38 .field("ext_preds", &self.ext_preds)
39 .field("rules", &self.rules)
40 .finish()
41 }
42}
43
/// A [`Program`] together with the strata computed by [`Program::stratify`].
#[derive(Clone)]
pub struct StratifiedProgram<'p> {
    /// The program that was stratified.
    program: Program<'p>,
    /// Predicate sets in the order produced by `stratify`: a stratum comes
    /// after the strata its predicates depend on.
    strata: Vec<PredicateSet>,
}
57
58impl<'p> fmt::Debug for StratifiedProgram<'p> {
59 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
60 f.debug_struct("StratifiedProgram")
61 .field("program", &self.program)
62 .field("strata", &self.strata)
63 .finish()
64 }
65}
66
/// Outgoing edges of one node: destination predicate -> whether the edge is negated.
type EdgeMap = FxHashMap<ast::PredicateIndex, bool>;
/// Dependency graph: source predicate -> its outgoing edges.
type DepGraph = FxHashMap<ast::PredicateIndex, EdgeMap>;
/// A set of predicates, used for strongly connected components.
type Nodeset = FxHashSet<ast::PredicateIndex>;
70
71impl<'p> Program<'p> {
72 pub fn new(arena: &'p Arena) -> Self {
73 Self {
74 arena,
75 ext_preds: Vec::new(),
76 rules: FxHashMap::default(),
77 }
78 }
79
80 pub fn add_clause<'src>(&mut self, src: &'src Arena, clause: &'src ast::Clause) {
81 let clause = self.arena.copy_clause(src, clause);
82 let sym = clause.head.sym;
83 use std::collections::hash_map::Entry;
84 match self.rules.entry(sym) {
85 Entry::Occupied(mut v) => v.get_mut().push(clause),
86 Entry::Vacant(v) => {
87 v.insert(vec![clause]);
88 }
89 }
90 }
91
92 pub fn arena(&'p self) -> &'p ast::Arena {
94 self.arena
95 }
96
97 pub fn extensional_preds(&'p self) -> PredicateSet {
99 let mut set = FxHashSet::default();
100 set.extend(&self.ext_preds);
101 set
102 }
103
104 pub fn intensional_preds(&'p self) -> PredicateSet {
106 let mut set = FxHashSet::default();
107 set.extend(self.rules.keys());
108 set
109 }
110
111 pub fn rules(&'p self, sym: ast::PredicateIndex) -> impl Iterator<Item = &'p ast::Clause<'p>> {
113 self.rules.get(&sym).unwrap().iter().copied()
114 }
115
116 pub fn stratify(self) -> Result<StratifiedProgram<'p>, String> {
125 let dep = make_dep_graph(&self);
126 let mut strata = dep.sccs();
127
128 let mut pred_to_stratum: FxHashMap<ast::PredicateIndex, usize> = FxHashMap::default();
129
130 for (i, c) in strata.iter().enumerate() {
131 for sym in c {
132 pred_to_stratum.insert(*sym, i);
133 }
134 for sym in c {
135 if let Some(edges) = dep.get(sym) {
136 for (dest, negated) in edges {
137 if !*negated {
138 continue;
139 }
140 let dest_stratum = pred_to_stratum.get(dest);
141 if let Some(dest_stratum) = dest_stratum
142 && *dest_stratum == i
143 {
144 return Err("program cannot be stratified".to_string());
145 }
146 }
147 }
148 }
149 }
150 dep.sort_result(&mut strata, pred_to_stratum);
151 let stratified = StratifiedProgram {
152 program: self,
153 strata: strata.into_iter().collect(),
154 };
155 Ok(stratified)
156 }
157}
158
159impl<'p> StratifiedProgram<'p> {
160 pub fn arena(&'p self) -> &'p ast::Arena {
162 self.program.arena()
163 }
164
165 pub fn extensional_preds(&'p self) -> PredicateSet {
167 self.program.extensional_preds()
168 }
169
170 pub fn intensional_preds(&'p self) -> PredicateSet {
172 self.program.intensional_preds()
173 }
174
175 pub fn rules(&'p self, sym: ast::PredicateIndex) -> impl Iterator<Item = &'p ast::Clause<'p>> {
177 self.program.rules(sym)
178 }
179
180 pub fn strata(&'p self) -> Vec<PredicateSet> {
183 self.strata.to_vec()
184 }
185
186 pub fn pred_to_index(&'p self, sym: ast::PredicateIndex) -> Option<usize> {
189 self.strata.iter().position(|x| x.contains(&sym))
190 }
191}
192
193fn make_dep_graph<'p>(program: &Program<'p>) -> DepGraph {
194 let mut dep: DepGraph = FxHashMap::default();
195
196 for (s, rule) in program.rules.iter() {
197 dep.init_node(*s);
198 for clause in rule.iter() {
199 for premise in clause.premises.iter() {
200 match premise {
201 ast::Term::Atom(atom_pred) => {
202 if !program.extensional_preds().contains(&atom_pred.sym) {
203 if clause.transform.is_empty() || clause.transform[0].var.is_some() {
204 dep.add_edge(*s, atom_pred.sym, false);
205 } else {
206 dep.add_edge(*s, atom_pred.sym, true);
207 }
208 }
209 }
210 ast::Term::NegAtom(atom_pred) => {
211 if !program.extensional_preds().contains(&atom_pred.sym) {
212 dep.add_edge(*s, atom_pred.sym, true);
213 }
214 }
215 ast::Term::TemporalAtom(atom_pred, _) => {
216 if !program.extensional_preds().contains(&atom_pred.sym) {
217 if clause.transform.is_empty() || clause.transform[0].var.is_some() {
218 dep.add_edge(*s, atom_pred.sym, false);
219 } else {
220 dep.add_edge(*s, atom_pred.sym, true);
221 }
222 }
223 }
224 _ => {}
225 }
226 }
227 }
228 }
229 dep
230}
231
/// Rearranges `arr` in place so that the element at index `k` ends up at
/// index `permutation[k]`.
///
/// The permutation is walked one cycle at a time, carrying a single element
/// along each cycle. `T: Default` is needed because the first slot of a cycle
/// is temporarily emptied with `std::mem::take`.
///
/// `permutation` is expected to be a permutation of `0..arr.len()`; indexing
/// panics if it is shorter than `arr` or contains an out-of-range index.
fn apply_permutation_cycle_rotate<T: Default>(arr: &mut [T], permutation: &[usize]) {
    let len = arr.len();
    let mut placed = vec![false; len];

    for start in 0..len {
        if placed[start] {
            continue;
        }
        placed[start] = true;
        if permutation[start] == start {
            // Fixed point: the element is already where it belongs.
            continue;
        }

        // Lift the first element out, then push it around the cycle; each
        // swap drops the carried element and picks up the displaced one.
        let mut carried = std::mem::take(&mut arr[start]);
        let mut pos = start;
        loop {
            let dest = permutation[pos];
            placed[pos] = true;
            std::mem::swap(&mut arr[dest], &mut carried);
            pos = dest;
            if pos == start {
                break;
            }
        }
    }
}
259
/// Graph operations on [`DepGraph`] used while stratifying a program.
trait DepGraphExt {
    /// Ensures `src` is a node of the graph; a new node starts without edges.
    fn init_node(&mut self, src: ast::PredicateIndex);
    /// Records the edge `src -> dest`. Once an edge has been recorded as
    /// negated it stays negated, even if it is added again as non-negated.
    fn add_edge(&mut self, src: ast::PredicateIndex, dest: ast::PredicateIndex, negated: bool);
    /// Returns a graph with every edge reversed and its negation flag kept.
    fn transpose(&self) -> DepGraph;
    /// Returns the strongly connected components of the graph.
    fn sccs(&self) -> Vec<Nodeset>;
    /// Reorders `strata` in place so that each stratum comes after the strata
    /// reachable through its predicates' outgoing edges, and returns
    /// `pred_to_stratum_map` re-keyed to the new positions.
    fn sort_result(
        &self,
        strata: &mut Vec<Nodeset>,
        pred_to_stratum_map: FxHashMap<ast::PredicateIndex, usize>,
    ) -> FxHashMap<ast::PredicateIndex, usize>;
}
271
272impl DepGraphExt for DepGraph {
273 fn init_node(&mut self, src: ast::PredicateIndex) {
274 self.entry(src).or_default();
275 }
276
277 fn add_edge(&mut self, src: ast::PredicateIndex, dest: ast::PredicateIndex, negated: bool) {
278 let edges = self.entry(src).or_default();
279 if negated {
280 edges.insert(dest, negated);
281 return;
282 }
283 if edges.get(&dest).is_none() || !edges[&dest] {
284 edges.insert(dest, false);
285 }
286 }
287
288 fn transpose(&self) -> DepGraph {
289 let mut rev: DepGraph = FxHashMap::default();
290 for (src, edges) in self.iter() {
291 for (dest, negated) in edges.iter() {
292 rev.init_node(*dest);
293 rev.add_edge(*dest, *src, *negated);
294 }
295 }
296 rev
297 }
298
299 fn sccs(&self) -> Vec<Nodeset> {
300 let mut s: Vec<ast::PredicateIndex> = Vec::new();
301 let mut seen: Nodeset = FxHashSet::default();
302
303 fn visit(
304 node: ast::PredicateIndex,
305 graph: &DepGraph,
306 s: &mut Vec<ast::PredicateIndex>,
307 seen: &mut Nodeset,
308 ) {
309 if !seen.contains(&node) {
310 seen.insert(node);
311 if let Some(edges) = graph.get(&node) {
312 for &neighbor in edges.keys() {
313 visit(neighbor, graph, s, seen);
314 }
315 }
316 s.push(node);
317 }
318 }
319
320 for (node, _) in self.iter() {
321 visit(*node, self, &mut s, &mut seen);
322 }
323
324 let rev = self.transpose();
325 let mut seen: Nodeset = FxHashSet::default();
326 fn rvisit(
327 node: ast::PredicateIndex,
328 rev: &DepGraph,
329 scc: &mut Nodeset,
330 seen: &mut Nodeset,
331 ) {
332 if !seen.contains(&node) {
333 seen.insert(node);
334 scc.insert(node);
335 if let Some(edges) = rev.get(&node) {
336 for &e in edges.keys() {
337 rvisit(e, rev, scc, seen);
338 }
339 }
340 }
341 }
342 let mut sccs: Vec<Nodeset> = Vec::new();
343 while let Some(top) = s.pop() {
344 if !seen.contains(&top) {
345 let mut scc: Nodeset = FxHashSet::default();
346 rvisit(top, &rev, &mut scc, &mut seen);
347 if !scc.is_empty() {
348 sccs.push(scc);
349 }
350 }
351 }
352 sccs
353 }
354
355 fn sort_result(
356 &self,
357 strata: &mut Vec<Nodeset>,
358 pred_to_stratum_map: FxHashMap<ast::PredicateIndex, usize>,
359 ) -> FxHashMap<ast::PredicateIndex, usize> {
360 let mut sorted_indices: Vec<usize> = Vec::new();
361 let mut seen: FxHashSet<usize> = FxHashSet::default();
362 let num_strata = strata.len();
363
364 fn visit_stratum(
365 index: usize,
366 dep: &DepGraph,
367 strata: &Vec<Nodeset>,
368 pred_to_stratum_map: &FxHashMap<ast::PredicateIndex, usize>,
369 seen: &mut FxHashSet<usize>,
370 sorted_indices: &mut Vec<usize>,
371 ) {
372 if seen.contains(&index) {
373 return;
374 }
375 seen.insert(index);
376
377 if let Some(scc) = strata.get(index) {
378 for sym in scc {
379 if let Some(edges) = dep.get(sym) {
380 for d in edges.keys() {
381 if let Some(&dep_stratum_index) = pred_to_stratum_map.get(d) {
382 visit_stratum(
383 dep_stratum_index,
384 dep,
385 strata,
386 pred_to_stratum_map,
387 seen,
388 sorted_indices,
389 );
390 }
391 }
392 }
393 }
394 }
395 sorted_indices.push(index);
396 }
397
398 for i in 0..num_strata {
399 visit_stratum(
400 i,
401 self,
402 strata,
403 &pred_to_stratum_map,
404 &mut seen,
405 &mut sorted_indices,
406 );
407 }
408
409 let mut permutation = vec![0; num_strata];
410 let mut old_to_new_map: FxHashMap<usize, usize> = FxHashMap::default();
411 for new_idx in 0..num_strata {
412 let old_idx = sorted_indices[new_idx];
413 permutation[old_idx] = new_idx;
414 old_to_new_map.insert(old_idx, new_idx);
415 }
416
417 apply_permutation_cycle_rotate(strata, &permutation);
418
419 let mut new_pred_to_stratum_map: FxHashMap<ast::PredicateIndex, usize> =
420 FxHashMap::default();
421 for (sym, &old_idx) in pred_to_stratum_map.iter() {
422 if let Some(&new_idx) = old_to_new_map.get(&old_idx) {
423 new_pred_to_stratum_map.insert(*sym, new_idx);
424 }
425 }
426 new_pred_to_stratum_map
427 }
428}
429
#[cfg(test)]
mod tests {
    use super::*;
    use mangle_parse::Parser;

    /// A program whose only negation points at a predicate defined by facts
    /// must stratify, with the negated predicate in a strictly lower stratum.
    #[test]
    fn test_stratification_success() {
        let arena = Arena::new_with_global_interner();
        let source = r#"
            p(1).
            q(X) :- p(X).
            r(X) :- q(X), !s(X).
            s(2).
        "#;
        let mut parser = Parser::new(&arena, source.as_bytes(), "test");
        parser.next_token().unwrap();
        let unit = parser.parse_unit().unwrap();

        let mut program = Program::new(&arena);
        for clause in unit.clauses {
            program.add_clause(&arena, clause);
        }

        let stratified = program.stratify().expect("should be stratifiable");

        // Resolves a predicate name to the index of the stratum holding it.
        let stratum_of = |name: &str| -> Option<usize> {
            arena
                .lookup_opt(name)
                .and_then(|name_idx| arena.lookup_predicate_sym(name_idx))
                .and_then(|pred_idx| stratified.pred_to_index(pred_idx))
        };

        let s = stratum_of("s");
        let r = stratum_of("r");
        let q = stratum_of("q");
        let p = stratum_of("p");

        assert!(s.is_some());
        assert!(r.is_some());
        assert!(q.is_some());
        assert!(p.is_some());

        let (s, r, q, p) = (s.unwrap(), r.unwrap(), q.unwrap(), p.unwrap());

        assert!(r > s, "r should be higher than s");

        assert!(q >= p, "q should be >= p");

        assert!(r >= q, "r should be >= q");
    }

    /// A predicate that depends on its own negation must be rejected.
    #[test]
    fn test_stratification_cycle() {
        let arena = Arena::new_with_global_interner();
        let source = "p(X) :- !p(X).";
        let mut parser = Parser::new(&arena, source.as_bytes(), "test");
        parser.next_token().unwrap();
        let unit = parser.parse_unit().unwrap();

        let mut program = Program::new(&arena);
        for clause in unit.clauses {
            program.add_clause(&arena, clause);
        }

        let outcome = program.stratify();
        assert!(outcome.is_err(), "should detect negation cycle");
    }
}