1use fxhash::{FxHashMap, FxHashSet};
16use mangle_ast as ast;
17
18pub fn rewrite_unit<'a>(arena: &'a ast::Arena, unit: &'a ast::Unit<'a>) -> ast::Unit<'a> {
20 let (pkg_name, used_pkgs) = find_package_info(arena, unit);
21
22 if pkg_name.is_empty() {
23 return ast::Unit {
24 decls: unit.decls,
25 clauses: unit.clauses,
26 };
27 }
28
29 let defined_preds = find_defined_preds(unit);
30 let mut renamer = Renamer {
31 arena,
32 pkg_name,
33 used_pkgs,
34 defined_preds,
35 cache: FxHashMap::default(),
36 };
37
38 let mut new_decls = Vec::with_capacity(unit.decls.len());
39 for &decl in unit.decls {
40 if let Some(new_decl) = renamer.rewrite_decl(decl) {
41 new_decls.push(new_decl);
42 }
43 }
44
45 let mut new_clauses = Vec::with_capacity(unit.clauses.len());
46 for &clause in unit.clauses {
47 if let Some(new_clause) = renamer.rewrite_clause(clause) {
48 new_clauses.push(new_clause);
49 }
50 }
51
52 ast::Unit {
53 decls: arena.alloc_slice_copy(&new_decls),
54 clauses: arena.alloc_slice_copy(&new_clauses),
55 }
56}
57
58fn find_package_info<'a>(
59 arena: &'a ast::Arena,
60 unit: &'a ast::Unit<'a>,
61) -> (&'a str, FxHashSet<&'a str>) {
62 let mut pkg_name = "";
63 let mut used_pkgs = FxHashSet::default();
64
65 for &decl in unit.decls {
66 let pred_name = arena.predicate_name(decl.atom.sym).unwrap_or("");
67 if pred_name == "Package" {
68 if let Some(desc) = find_name_desc(arena, decl.descr) {
69 pkg_name = desc;
70 }
71 } else if pred_name == "Use"
72 && let Some(desc) = find_name_desc(arena, decl.descr)
73 {
74 used_pkgs.insert(desc);
75 }
76 }
77 (pkg_name, used_pkgs)
78}
79
80fn find_name_desc<'a>(arena: &'a ast::Arena, descr: &'a [&'a ast::Atom<'a>]) -> Option<&'a str> {
81 for &atom in descr {
82 if arena.predicate_name(atom.sym).unwrap_or("") == "name"
83 && let Some(&ast::BaseTerm::Const(ast::Const::String(s))) = atom.args.first().copied()
84 {
85 return Some(s);
86 }
87 }
88 None
89}
90
91fn find_defined_preds(unit: &ast::Unit) -> FxHashSet<ast::PredicateIndex> {
92 let mut defined = FxHashSet::default();
93 for &decl in unit.decls {
94 defined.insert(decl.atom.sym);
95 }
96 for &clause in unit.clauses {
97 defined.insert(clause.head.sym);
98 }
99 defined
100}
101
102struct Renamer<'a> {
103 arena: &'a ast::Arena,
104 pkg_name: &'a str,
105 used_pkgs: FxHashSet<&'a str>,
106 defined_preds: FxHashSet<ast::PredicateIndex>,
107 cache: FxHashMap<ast::PredicateIndex, ast::PredicateIndex>,
108}
109
110impl<'a> Renamer<'a> {
111 fn rename_pred(&mut self, sym: ast::PredicateIndex) -> Option<ast::PredicateIndex> {
112 if let Some(&new_sym) = self.cache.get(&sym) {
113 return Some(new_sym);
114 }
115
116 let name = self.arena.predicate_name(sym)?;
117
118 if name == "Package" || name == "Use" {
120 self.cache.insert(sym, sym);
121 return Some(sym);
122 }
123
124 if self.defined_preds.contains(&sym) {
125 let new_name = format!("{}.{}", self.pkg_name, name);
126 let new_sym = self
127 .arena
128 .predicate_sym(&new_name, self.arena.predicate_arity(sym));
129 self.cache.insert(sym, new_sym);
130 return Some(new_sym);
131 }
132
133 if let Some(dot_idx) = name.rfind('.') {
135 let prefix = &name[..dot_idx];
136 if self.used_pkgs.contains(prefix) {
137 self.cache.insert(sym, sym);
139 return Some(sym);
140 }
141 }
142
143 self.cache.insert(sym, sym);
145 Some(sym)
146 }
147
148 fn rewrite_decl(&mut self, decl: &'a ast::Decl<'a>) -> Option<&'a ast::Decl<'a>> {
149 let pred_name = self.arena.predicate_name(decl.atom.sym).unwrap_or("");
150 if pred_name == "Package" || pred_name == "Use" {
151 return None;
153 }
154
155 let new_atom = self.rewrite_atom(decl.atom)?;
156
157 let bounds = if let Some(bs) = decl.bounds {
159 let mut new_bounds = Vec::new();
160 for &b in bs {
161 let new_base_terms: Vec<&ast::BaseTerm> = b
162 .base_terms
163 .iter()
164 .map(|&t| {
165 if let ast::BaseTerm::Const(ast::Const::Name(name_idx)) = t {
167 let name = self.arena.lookup_name(*name_idx).unwrap_or("");
168 if let Some(pred_idx) = self.arena.lookup_predicate_sym(*name_idx)
170 && self.defined_preds.contains(&pred_idx)
171 {
172 let new_name = format!("{}.{}", self.pkg_name, name);
173 let new_name_idx = self.arena.intern(&new_name);
174 return &*self
175 .arena
176 .alloc(ast::BaseTerm::Const(ast::Const::Name(new_name_idx)));
177 }
178 }
179 t
180 })
181 .collect();
182 new_bounds.push(&*self.arena.alloc(ast::BoundDecl {
183 base_terms: self.arena.alloc_slice_copy(&new_base_terms),
184 }));
185 }
186 Some(self.arena.alloc_slice_copy(&new_bounds) as &'a [&'a ast::BoundDecl<'a>])
187 } else {
188 None
189 };
190
191 Some(self.arena.alloc(ast::Decl {
192 atom: new_atom,
193 descr: decl.descr, bounds,
195 constraints: decl.constraints,
196 }))
197 }
198
199 fn rewrite_clause(&mut self, clause: &'a ast::Clause<'a>) -> Option<&'a ast::Clause<'a>> {
200 let head = self.rewrite_atom(clause.head)?;
201 let mut premises = Vec::new();
202 for &premise in clause.premises {
203 match premise {
204 ast::Term::Atom(a) => {
205 premises.push(&*self.arena.alloc(ast::Term::Atom(self.rewrite_atom(a)?)));
206 }
207 ast::Term::NegAtom(a) => {
208 premises.push(&*self.arena.alloc(ast::Term::NegAtom(self.rewrite_atom(a)?)));
209 }
210 _ => premises.push(premise),
211 }
212 }
213
214 Some(self.arena.alloc(ast::Clause {
215 head,
216 premises: self.arena.alloc_slice_copy(&premises),
217 transform: clause.transform,
218 }))
219 }
220
221 fn rewrite_atom(&mut self, atom: &'a ast::Atom<'a>) -> Option<&'a ast::Atom<'a>> {
222 let new_sym = self.rename_pred(atom.sym)?;
223 Some(self.arena.atom(new_sym, atom.args))
224 }
225}
226
#[cfg(test)]
mod tests {
    use super::*;
    use googletest::prelude::*;

    /// Builds a zero-arity `kind` declaration carrying a `name(<name>)` descriptor.
    fn make_named_decl<'a>(arena: &'a ast::Arena, kind: &str, name: &str) -> &'a ast::Decl<'a> {
        let kind_sym = arena.predicate_sym(kind, Some(0));
        let name_sym = arena.predicate_sym("name", Some(1));
        let name_term = arena.alloc(ast::BaseTerm::Const(ast::Const::String(
            arena.alloc_str(name),
        )));
        arena.alloc(ast::Decl {
            atom: arena.atom(kind_sym, &[]),
            descr: arena.alloc_slice_copy(&[arena.atom(name_sym, &[name_term])]),
            bounds: None,
            constraints: None,
        })
    }

    /// `Package` declaration named `name`.
    fn make_pkg_decl<'a>(arena: &'a ast::Arena, name: &str) -> &'a ast::Decl<'a> {
        make_named_decl(arena, "Package", name)
    }

    /// `Use` declaration naming package `name`.
    fn make_use_decl<'a>(arena: &'a ast::Arena, name: &str) -> &'a ast::Decl<'a> {
        make_named_decl(arena, "Use", name)
    }

    /// Predicate name of a clause head.
    fn head_name<'a>(arena: &'a ast::Arena, clause: &ast::Clause<'_>) -> &'a str {
        arena.predicate_name(clause.head.sym).unwrap()
    }

    /// Predicate name of a positive-atom premise; panics for any other premise kind so
    /// a shape regression fails the test instead of silently skipping the assertion.
    fn atom_premise_name<'a>(arena: &'a ast::Arena, term: &ast::Term<'_>) -> &'a str {
        let ast::Term::Atom(atom) = term else {
            panic!("Expected Atom premise");
        };
        arena.predicate_name(atom.sym).unwrap()
    }

    #[test]
    fn test_rename_simple() {
        let arena = ast::Arena::new_with_global_interner();
        let pkg_decl = make_pkg_decl(&arena, "pkg");

        let foo_sym = arena.predicate_sym("foo", Some(1));
        let bar_sym = arena.predicate_sym("bar", Some(1));
        let var_x = arena.variable("X");
        let const_1 = arena.const_(ast::Const::Number(1));

        // foo(X) :- bar(X).
        let clause1 = arena.alloc(ast::Clause {
            head: arena.atom(foo_sym, &[var_x]),
            premises: arena
                .alloc_slice_copy(&[arena.alloc(ast::Term::Atom(arena.atom(bar_sym, &[var_x])))]),
            transform: &[],
        });

        // bar(1).
        let clause2 = arena.alloc(ast::Clause {
            head: arena.atom(bar_sym, &[const_1]),
            premises: &[],
            transform: &[],
        });

        let unit = ast::Unit {
            decls: arena.alloc_slice_copy(&[pkg_decl]),
            clauses: arena.alloc_slice_copy(&[clause1, clause2]),
        };

        let new_unit = rewrite_unit(&arena, &unit);

        // The Package declaration is consumed by the rewrite.
        assert_that!(new_unit.decls.len(), eq(0));
        assert_that!(new_unit.clauses.len(), eq(2));

        let (c1, c2) = (new_unit.clauses[0], new_unit.clauses[1]);
        assert_that!(head_name(&arena, c1), eq("pkg.foo"));
        assert_that!(atom_premise_name(&arena, c1.premises[0]), eq("pkg.bar"));
        assert_that!(head_name(&arena, c2), eq("pkg.bar"));
    }

    #[test]
    fn test_rename_with_use() {
        let arena = ast::Arena::new_with_global_interner();
        let pkg_decl = make_pkg_decl(&arena, "pkg");
        let use_decl = make_use_decl(&arena, "other");

        let foo_sym = arena.predicate_sym("foo", Some(1));
        let other_bar_sym = arena.predicate_sym("other.bar", Some(1));
        let var_x = arena.variable("X");

        // foo(X) :- other.bar(X).
        let clause1 = arena.alloc(ast::Clause {
            head: arena.atom(foo_sym, &[var_x]),
            premises: arena.alloc_slice_copy(&[
                arena.alloc(ast::Term::Atom(arena.atom(other_bar_sym, &[var_x])))
            ]),
            transform: &[],
        });

        let unit = ast::Unit {
            decls: arena.alloc_slice_copy(&[pkg_decl, use_decl]),
            clauses: arena.alloc_slice_copy(&[clause1]),
        };

        let new_unit = rewrite_unit(&arena, &unit);

        // Both the Package and the Use declarations are consumed.
        assert_that!(new_unit.decls.len(), eq(0));
        assert_that!(new_unit.clauses.len(), eq(1));

        let c1 = new_unit.clauses[0];
        assert_that!(head_name(&arena, c1), eq("pkg.foo"));
        assert_that!(atom_premise_name(&arena, c1.premises[0]), eq("other.bar"));
    }

    #[test]
    fn test_unused_dotted_reference_is_unchanged() {
        let arena = ast::Arena::new_with_global_interner();
        let pkg_decl = make_pkg_decl(&arena, "pkg");

        // foo :- unused.baz.   (no `Use` for "unused", and not defined locally)
        let clause = arena.alloc(ast::Clause {
            head: arena.atom(arena.predicate_sym("foo", None), &[]),
            premises: arena.alloc_slice_copy(&[arena.alloc(ast::Term::Atom(
                arena.atom(arena.predicate_sym("unused.baz", None), &[]),
            ))]),
            transform: &[],
        });

        let unit = ast::Unit {
            decls: arena.alloc_slice_copy(&[pkg_decl]),
            clauses: arena.alloc_slice_copy(&[clause]),
        };
        let new_unit = rewrite_unit(&arena, &unit);

        assert_that!(head_name(&arena, new_unit.clauses[0]), eq("pkg.foo"));
        assert_that!(
            atom_premise_name(&arena, new_unit.clauses[0].premises[0]),
            eq("unused.baz")
        );
    }

    #[test]
    fn test_go_case_no_package() {
        let arena = ast::Arena::new_with_global_interner();

        let clause = arena.alloc(ast::Clause {
            head: arena.atom(arena.predicate_sym("clause_defined_here", None), &[]),
            premises: arena.alloc_slice_copy(&[arena.alloc(ast::Term::Atom(
                arena.atom(arena.predicate_sym("other_clause", None), &[]),
            ))]),
            transform: &[],
        });
        let unit = ast::Unit {
            decls: &[],
            clauses: arena.alloc_slice_copy(&[clause]),
        };
        let new_unit = rewrite_unit(&arena, &unit);

        assert_that!(
            head_name(&arena, new_unit.clauses[0]),
            eq("clause_defined_here")
        );
        assert_that!(
            atom_premise_name(&arena, new_unit.clauses[0].premises[0]),
            eq("other_clause")
        );
    }

    #[test]
    fn test_go_case_external_refs() {
        let arena = ast::Arena::new_with_global_interner();
        let pkg_decl = make_pkg_decl(&arena, "foo.bar");

        let clause = arena.alloc(ast::Clause {
            head: arena.atom(arena.predicate_sym("clause_defined_here", None), &[]),
            premises: arena.alloc_slice_copy(&[arena.alloc(ast::Term::Atom(
                arena.atom(arena.predicate_sym("other_clause", None), &[]),
            ))]),
            transform: &[],
        });

        let unit = ast::Unit {
            decls: arena.alloc_slice_copy(&[pkg_decl]),
            clauses: arena.alloc_slice_copy(&[clause]),
        };
        let new_unit = rewrite_unit(&arena, &unit);

        assert_that!(
            head_name(&arena, new_unit.clauses[0]),
            eq("foo.bar.clause_defined_here")
        );
        // `other_clause` is not defined in this unit, so it keeps its name.
        assert_that!(
            atom_premise_name(&arena, new_unit.clauses[0].premises[0]),
            eq("other_clause")
        );
    }

    #[test]
    fn test_go_case_rewritten_local() {
        let arena = ast::Arena::new_with_global_interner();
        let pkg_decl = make_pkg_decl(&arena, "foo.bar");

        let defined_sym = arena.predicate_sym("clause_defined_here", None);
        let other_sym = arena.predicate_sym("other_clause", None);

        // other_clause.
        let clause1 = arena.alloc(ast::Clause {
            head: arena.atom(other_sym, &[]),
            premises: &[],
            transform: &[],
        });

        // clause_defined_here :- other_clause.
        let clause2 = arena.alloc(ast::Clause {
            head: arena.atom(defined_sym, &[]),
            premises: arena
                .alloc_slice_copy(&[arena.alloc(ast::Term::Atom(arena.atom(other_sym, &[])))]),
            transform: &[],
        });

        let unit = ast::Unit {
            decls: arena.alloc_slice_copy(&[pkg_decl]),
            clauses: arena.alloc_slice_copy(&[clause1, clause2]),
        };
        let new_unit = rewrite_unit(&arena, &unit);

        assert_that!(
            head_name(&arena, new_unit.clauses[0]),
            eq("foo.bar.other_clause")
        );
        assert_that!(
            head_name(&arena, new_unit.clauses[1]),
            eq("foo.bar.clause_defined_here")
        );
        assert_that!(
            atom_premise_name(&arena, new_unit.clauses[1].premises[0]),
            eq("foo.bar.other_clause")
        );
    }

    #[test]
    fn test_go_case_negation() {
        let arena = ast::Arena::new_with_global_interner();
        let pkg_decl = make_pkg_decl(&arena, "foo.bar");

        let defined_sym = arena.predicate_sym("clause_defined_here", None);
        let other_sym = arena.predicate_sym("other_clause", None);

        // other_clause.
        let clause1 = arena.alloc(ast::Clause {
            head: arena.atom(other_sym, &[]),
            premises: &[],
            transform: &[],
        });

        // clause_defined_here :- !other_clause.
        let clause2 = arena.alloc(ast::Clause {
            head: arena.atom(defined_sym, &[]),
            premises: arena
                .alloc_slice_copy(&[arena.alloc(ast::Term::NegAtom(arena.atom(other_sym, &[])))]),
            transform: &[],
        });

        let unit = ast::Unit {
            decls: arena.alloc_slice_copy(&[pkg_decl]),
            clauses: arena.alloc_slice_copy(&[clause1, clause2]),
        };
        let new_unit = rewrite_unit(&arena, &unit);

        let ast::Term::NegAtom(negated) = new_unit.clauses[1].premises[0] else {
            panic!("Expected NegAtom");
        };
        let p = arena.predicate_name(negated.sym).unwrap();
        assert_that!(p, eq("foo.bar.other_clause"));
    }

    #[test]
    fn test_go_case_decl_only() {
        let arena = ast::Arena::new_with_global_interner();
        let pkg_decl = make_pkg_decl(&arena, "foo.bar");

        let clause_sym = arena.predicate_sym("clause", None);
        let decl_sym = arena.predicate_sym("from_decl", None);

        // Decl from_decl.   (declared only, never derived by a clause)
        let decl = arena.alloc(ast::Decl {
            atom: arena.atom(decl_sym, &[]),
            descr: &[],
            bounds: None,
            constraints: None,
        });

        // clause :- from_decl.
        let clause = arena.alloc(ast::Clause {
            head: arena.atom(clause_sym, &[]),
            premises: arena
                .alloc_slice_copy(&[arena.alloc(ast::Term::Atom(arena.atom(decl_sym, &[])))]),
            transform: &[],
        });

        let unit = ast::Unit {
            decls: arena.alloc_slice_copy(&[pkg_decl, decl]),
            clauses: arena.alloc_slice_copy(&[clause]),
        };
        let new_unit = rewrite_unit(&arena, &unit);

        assert_that!(head_name(&arena, new_unit.clauses[0]), eq("foo.bar.clause"));
        assert_that!(
            atom_premise_name(&arena, new_unit.clauses[0].premises[0]),
            eq("foo.bar.from_decl")
        );

        let d_name = arena.predicate_name(new_unit.decls[0].atom.sym).unwrap();
        assert_that!(d_name, eq("foo.bar.from_decl"));
    }
}