1use crate::PredicateSet;
16use fxhash::{FxHashMap, FxHashSet};
17use mangle_ast as ast;
18use mangle_ast::Arena;
19use std::fmt;
20
/// A program: its clauses grouped by head predicate, plus the list of predicates
/// declared extensional.
#[derive(Clone)]
pub struct Program<'p> {
    /// Arena into which `add_clause` copies every clause.
    pub arena: &'p Arena,
    /// Predicates treated as extensional; dependencies on these predicates are left
    /// out of the dependency graph used by `stratify`.
    pub ext_preds: Vec<ast::PredicateIndex>,
    /// Clauses keyed by the predicate symbol of their head.
    pub rules: FxHashMap<ast::PredicateIndex, Vec<&'p ast::Clause<'p>>>,
}
34
35impl<'p> fmt::Debug for Program<'p> {
36 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
37 f.debug_struct("Program")
38 .field("ext_preds", &self.ext_preds)
39 .field("rules", &self.rules)
40 .finish()
41 }
42}
43
/// A [`Program`] together with its strata, as produced by `Program::stratify`.
#[derive(Clone)]
pub struct StratifiedProgram<'p> {
    /// The program that was stratified.
    program: Program<'p>,
    /// One predicate set per strongly connected component of the program's dependency
    /// graph, ordered so that every stratum comes after the strata it depends on.
    strata: Vec<PredicateSet>,
}
57
58impl<'p> fmt::Debug for StratifiedProgram<'p> {
59 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
60 f.debug_struct("StratifiedProgram")
61 .field("program", &self.program)
62 .field("strata", &self.strata)
63 .finish()
64 }
65}
66
/// Outgoing edges of one node: target predicate -> `true` if the dependency is negated.
type EdgeMap = FxHashMap<ast::PredicateIndex, bool>;
/// Dependency graph: each predicate mapped to the predicates its rules depend on.
type DepGraph = FxHashMap<ast::PredicateIndex, EdgeMap>;
/// A set of predicates (used for SCCs and visited sets).
type Nodeset = FxHashSet<ast::PredicateIndex>;
70
71impl<'p> Program<'p> {
72 pub fn new(arena: &'p Arena) -> Self {
73 Self {
74 arena,
75 ext_preds: Vec::new(),
76 rules: FxHashMap::default(),
77 }
78 }
79
80 pub fn add_clause<'src>(&mut self, src: &'src Arena, clause: &'src ast::Clause) {
81 let clause = self.arena.copy_clause(src, clause);
82 let sym = clause.head.sym;
83 use std::collections::hash_map::Entry;
84 match self.rules.entry(sym) {
85 Entry::Occupied(mut v) => v.get_mut().push(clause),
86 Entry::Vacant(v) => {
87 v.insert(vec![clause]);
88 }
89 }
90 }
91
92 pub fn arena(&'p self) -> &'p ast::Arena {
94 self.arena
95 }
96
97 pub fn extensional_preds(&'p self) -> PredicateSet {
99 let mut set = FxHashSet::default();
100 set.extend(&self.ext_preds);
101 set
102 }
103
104 pub fn intensional_preds(&'p self) -> PredicateSet {
106 let mut set = FxHashSet::default();
107 set.extend(self.rules.keys());
108 set
109 }
110
111 pub fn rules(&'p self, sym: ast::PredicateIndex) -> impl Iterator<Item = &'p ast::Clause<'p>> {
113 self.rules.get(&sym).unwrap().iter().copied()
114 }
115
116 pub fn stratify(self) -> Result<StratifiedProgram<'p>, String> {
125 let dep = make_dep_graph(&self);
126 let mut strata = dep.sccs();
127
128 let mut pred_to_stratum: FxHashMap<ast::PredicateIndex, usize> = FxHashMap::default();
129
130 for (i, c) in strata.iter().enumerate() {
131 for sym in c {
132 pred_to_stratum.insert(*sym, i);
133 }
134 for sym in c {
135 if let Some(edges) = dep.get(sym) {
136 for (dest, negated) in edges {
137 if !*negated {
138 continue;
139 }
140 let dest_stratum = pred_to_stratum.get(dest);
141 if let Some(dest_stratum) = dest_stratum
142 && *dest_stratum == i
143 {
144 return Err("program cannot be stratified".to_string());
145 }
146 }
147 }
148 }
149 }
150 dep.sort_result(&mut strata, pred_to_stratum);
151 let stratified = StratifiedProgram {
152 program: self,
153 strata: strata.into_iter().collect(),
154 };
155 Ok(stratified)
156 }
157}
158
159impl<'p> StratifiedProgram<'p> {
160 pub fn arena(&'p self) -> &'p ast::Arena {
162 self.program.arena()
163 }
164
165 pub fn extensional_preds(&'p self) -> PredicateSet {
167 self.program.extensional_preds()
168 }
169
170 pub fn intensional_preds(&'p self) -> PredicateSet {
172 self.program.intensional_preds()
173 }
174
175 pub fn rules(&'p self, sym: ast::PredicateIndex) -> impl Iterator<Item = &'p ast::Clause<'p>> {
177 self.program.rules(sym)
178 }
179
180 pub fn strata(&'p self) -> Vec<PredicateSet> {
183 self.strata.to_vec()
184 }
185
186 pub fn pred_to_index(&'p self, sym: ast::PredicateIndex) -> Option<usize> {
189 self.strata.iter().position(|x| x.contains(&sym))
190 }
191}
192
193fn make_dep_graph<'p>(program: &Program<'p>) -> DepGraph {
194 let mut dep: DepGraph = FxHashMap::default();
195
196 for (s, rule) in program.rules.iter() {
197 dep.init_node(*s);
198 for clause in rule.iter() {
199 for premise in clause.premises.iter() {
200 match premise {
201 ast::Term::Atom(atom_pred) => {
202 if !program.extensional_preds().contains(&atom_pred.sym) {
203 if clause.transform.is_empty() || clause.transform[0].var.is_some() {
204 dep.add_edge(*s, atom_pred.sym, false);
205 } else {
206 dep.add_edge(*s, atom_pred.sym, true);
207 }
208 }
209 }
210 ast::Term::NegAtom(atom_pred) => {
211 if !program.extensional_preds().contains(&atom_pred.sym) {
212 dep.add_edge(*s, atom_pred.sym, true);
213 }
214 }
215 _ => {}
216 }
217 }
218 }
219 }
220 dep
221}
222
/// Reorders `arr` in place so that the element originally at index `i` ends up at
/// index `permutation[i]`.
///
/// Each cycle of the permutation is resolved by swapping through the cycle's first
/// index. No placeholder value is needed, so `T` does not have to implement `Default`.
///
/// # Panics
///
/// Panics if `permutation` is not a bijection on `0..arr.len()`, i.e. a target index
/// is out of range or is hit twice. Without this check a non-bijective input would
/// loop forever. `permutation.len() == arr.len()` is checked in debug builds only.
fn apply_permutation_cycle_rotate<T>(arr: &mut [T], permutation: &[usize]) {
    debug_assert_eq!(
        arr.len(),
        permutation.len(),
        "permutation length must match slice length"
    );
    let mut visited = vec![false; arr.len()];
    for start in 0..arr.len() {
        if visited[start] {
            continue;
        }
        visited[start] = true;
        // Before each swap, arr[start] holds the element whose destination is `next`.
        let mut next = permutation[start];
        while next != start {
            assert!(
                !visited[next],
                "apply_permutation_cycle_rotate: input is not a permutation"
            );
            arr.swap(start, next);
            visited[next] = true;
            next = permutation[next];
        }
    }
}
250
/// Graph operations on `DepGraph` used during stratification.
trait DepGraphExt {
    /// Ensures `src` has an entry, possibly with no outgoing edges.
    fn init_node(&mut self, src: ast::PredicateIndex);
    /// Adds the edge `src -> dest`. An edge already flagged negated stays negated when
    /// the same edge is added again as positive.
    fn add_edge(&mut self, src: ast::PredicateIndex, dest: ast::PredicateIndex, negated: bool);
    /// Returns a new graph with every edge reversed, keeping each edge's negation flag.
    fn transpose(&self) -> DepGraph;
    /// Returns the strongly connected components of the graph.
    fn sccs(&self) -> Vec<Nodeset>;
    /// Reorders `strata` in place so that every stratum comes after the strata its
    /// predicates depend on. Returns `pred_to_stratum_map` rewritten to the new stratum
    /// positions.
    fn sort_result(
        &self,
        strata: &mut Vec<Nodeset>,
        pred_to_stratum_map: FxHashMap<ast::PredicateIndex, usize>,
    ) -> FxHashMap<ast::PredicateIndex, usize>;
}
262
263impl DepGraphExt for DepGraph {
264 fn init_node(&mut self, src: ast::PredicateIndex) {
265 self.entry(src).or_default();
266 }
267
268 fn add_edge(&mut self, src: ast::PredicateIndex, dest: ast::PredicateIndex, negated: bool) {
269 let edges = self.entry(src).or_default();
270 if negated {
271 edges.insert(dest, negated);
272 return;
273 }
274 if edges.get(&dest).is_none() || !edges[&dest] {
275 edges.insert(dest, false);
276 }
277 }
278
279 fn transpose(&self) -> DepGraph {
280 let mut rev: DepGraph = FxHashMap::default();
281 for (src, edges) in self.iter() {
282 for (dest, negated) in edges.iter() {
283 rev.init_node(*dest);
284 rev.add_edge(*dest, *src, *negated);
285 }
286 }
287 rev
288 }
289
290 fn sccs(&self) -> Vec<Nodeset> {
291 let mut s: Vec<ast::PredicateIndex> = Vec::new();
292 let mut seen: Nodeset = FxHashSet::default();
293
294 fn visit(
295 node: ast::PredicateIndex,
296 graph: &DepGraph,
297 s: &mut Vec<ast::PredicateIndex>,
298 seen: &mut Nodeset,
299 ) {
300 if !seen.contains(&node) {
301 seen.insert(node);
302 if let Some(edges) = graph.get(&node) {
303 for &neighbor in edges.keys() {
304 visit(neighbor, graph, s, seen);
305 }
306 }
307 s.push(node);
308 }
309 }
310
311 for (node, _) in self.iter() {
312 visit(*node, self, &mut s, &mut seen);
313 }
314
315 let rev = self.transpose();
316 let mut seen: Nodeset = FxHashSet::default();
317 fn rvisit(
318 node: ast::PredicateIndex,
319 rev: &DepGraph,
320 scc: &mut Nodeset,
321 seen: &mut Nodeset,
322 ) {
323 if !seen.contains(&node) {
324 seen.insert(node);
325 scc.insert(node);
326 if let Some(edges) = rev.get(&node) {
327 for &e in edges.keys() {
328 rvisit(e, rev, scc, seen);
329 }
330 }
331 }
332 }
333 let mut sccs: Vec<Nodeset> = Vec::new();
334 while let Some(top) = s.pop() {
335 if !seen.contains(&top) {
336 let mut scc: Nodeset = FxHashSet::default();
337 rvisit(top, &rev, &mut scc, &mut seen);
338 if !scc.is_empty() {
339 sccs.push(scc);
340 }
341 }
342 }
343 sccs
344 }
345
346 fn sort_result(
347 &self,
348 strata: &mut Vec<Nodeset>,
349 pred_to_stratum_map: FxHashMap<ast::PredicateIndex, usize>,
350 ) -> FxHashMap<ast::PredicateIndex, usize> {
351 let mut sorted_indices: Vec<usize> = Vec::new();
352 let mut seen: FxHashSet<usize> = FxHashSet::default();
353 let num_strata = strata.len();
354
355 fn visit_stratum(
356 index: usize,
357 dep: &DepGraph,
358 strata: &Vec<Nodeset>,
359 pred_to_stratum_map: &FxHashMap<ast::PredicateIndex, usize>,
360 seen: &mut FxHashSet<usize>,
361 sorted_indices: &mut Vec<usize>,
362 ) {
363 if seen.contains(&index) {
364 return;
365 }
366 seen.insert(index);
367
368 if let Some(scc) = strata.get(index) {
369 for sym in scc {
370 if let Some(edges) = dep.get(sym) {
371 for d in edges.keys() {
372 if let Some(&dep_stratum_index) = pred_to_stratum_map.get(d) {
373 visit_stratum(
374 dep_stratum_index,
375 dep,
376 strata,
377 pred_to_stratum_map,
378 seen,
379 sorted_indices,
380 );
381 }
382 }
383 }
384 }
385 }
386 sorted_indices.push(index);
387 }
388
389 for i in 0..num_strata {
390 visit_stratum(
391 i,
392 self,
393 strata,
394 &pred_to_stratum_map,
395 &mut seen,
396 &mut sorted_indices,
397 );
398 }
399
400 let mut permutation = vec![0; num_strata];
401 let mut old_to_new_map: FxHashMap<usize, usize> = FxHashMap::default();
402 for new_idx in 0..num_strata {
403 let old_idx = sorted_indices[new_idx];
404 permutation[old_idx] = new_idx;
405 old_to_new_map.insert(old_idx, new_idx);
406 }
407
408 apply_permutation_cycle_rotate(strata, &permutation);
409
410 let mut new_pred_to_stratum_map: FxHashMap<ast::PredicateIndex, usize> =
411 FxHashMap::default();
412 for (sym, &old_idx) in pred_to_stratum_map.iter() {
413 if let Some(&new_idx) = old_to_new_map.get(&old_idx) {
414 new_pred_to_stratum_map.insert(*sym, new_idx);
415 }
416 }
417 new_pred_to_stratum_map
418 }
419}
420
#[cfg(test)]
mod tests {
    use super::*;
    use mangle_parse::Parser;

    #[test]
    fn test_stratification_success() {
        let arena = Arena::new_with_global_interner();
        let source = r#"
            p(1).
            q(X) :- p(X).
            r(X) :- q(X), !s(X).
            s(2).
        "#;
        let mut parser = Parser::new(&arena, source.as_bytes(), "test");
        parser.next_token().unwrap();
        let unit = parser.parse_unit().unwrap();

        let mut program = Program::new(&arena);
        for clause in unit.clauses {
            program.add_clause(&arena, clause);
        }

        let stratified = program.stratify().expect("should be stratifiable");

        let get_stratum = |name: &str| -> Option<usize> {
            let name_idx = arena.lookup_opt(name)?;
            let pred_idx = arena.lookup_predicate_sym(name_idx)?;
            stratified.pred_to_index(pred_idx)
        };

        let s_idx = get_stratum("s");
        let r_idx = get_stratum("r");
        let q_idx = get_stratum("q");
        let p_idx = get_stratum("p");

        assert!(s_idx.is_some());
        assert!(r_idx.is_some());
        assert!(q_idx.is_some());
        assert!(p_idx.is_some());

        // `r` negates `s`, so `s` must be fully computed in an earlier stratum.
        assert!(r_idx.unwrap() > s_idx.unwrap(), "r should be higher than s");

        assert!(q_idx.unwrap() >= p_idx.unwrap(), "q should be >= p");

        assert!(r_idx.unwrap() >= q_idx.unwrap(), "r should be >= q");
    }

    #[test]
    fn test_stratification_cycle() {
        let arena = Arena::new_with_global_interner();
        let source = "p(X) :- !p(X).";
        let mut parser = Parser::new(&arena, source.as_bytes(), "test");
        parser.next_token().unwrap();
        let unit = parser.parse_unit().unwrap();

        let mut program = Program::new(&arena);
        for clause in unit.clauses {
            program.add_clause(&arena, clause);
        }

        let res = program.stratify();
        assert!(res.is_err(), "should detect negation cycle");
    }

    #[test]
    fn test_stratification_mutual_negation_cycle() {
        let arena = Arena::new_with_global_interner();
        // `p` and `q` are mutually recursive and `q` negates `p`, so the negated edge
        // lies inside a single SCC with two members.
        let source = r#"
            e(1).
            p(X) :- e(X), q(X).
            q(X) :- e(X), !p(X).
        "#;
        let mut parser = Parser::new(&arena, source.as_bytes(), "test");
        parser.next_token().unwrap();
        let unit = parser.parse_unit().unwrap();

        let mut program = Program::new(&arena);
        for clause in unit.clauses {
            program.add_clause(&arena, clause);
        }

        assert!(
            program.stratify().is_err(),
            "should detect negation through mutual recursion"
        );
    }

    #[test]
    fn test_positive_recursion_is_stratifiable() {
        let arena = Arena::new_with_global_interner();
        let source = r#"
            link(1, 2).
            link(2, 3).
            reachable(X, Y) :- link(X, Y).
            reachable(X, Z) :- link(X, Y), reachable(Y, Z).
        "#;
        let mut parser = Parser::new(&arena, source.as_bytes(), "test");
        parser.next_token().unwrap();
        let unit = parser.parse_unit().unwrap();

        let mut program = Program::new(&arena);
        for clause in unit.clauses {
            program.add_clause(&arena, clause);
        }

        let stratified = program
            .stratify()
            .expect("positive recursion should be stratifiable");

        let get_stratum = |name: &str| -> Option<usize> {
            let name_idx = arena.lookup_opt(name)?;
            let pred_idx = arena.lookup_predicate_sym(name_idx)?;
            stratified.pred_to_index(pred_idx)
        };

        let link_idx = get_stratum("link").expect("link should be in a stratum");
        let reach_idx = get_stratum("reachable").expect("reachable should be in a stratum");
        // Different SCCs, and `reachable` depends on `link`, so it must come strictly later.
        assert!(reach_idx > link_idx, "reachable should come after link");
    }

    #[test]
    fn test_apply_permutation_cycle_rotate() {
        let mut values = vec![10, 20, 30, 40, 50];
        // 0 -> 2 -> 4 -> 0 is a 3-cycle and 1 <-> 3 is a 2-cycle.
        apply_permutation_cycle_rotate(&mut values, &[2, 3, 4, 1, 0]);
        assert_eq!(values, vec![50, 40, 10, 20, 30]);
    }
}