Skip to main content

mangle_analysis/
rename.rs

1// Copyright 2025 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use fxhash::{FxHashMap, FxHashSet};
16use mangle_ast as ast;
17
18/// Rewrites a unit by prefixing local predicates with the package name.
19pub fn rewrite_unit<'a>(arena: &'a ast::Arena, unit: &'a ast::Unit<'a>) -> ast::Unit<'a> {
20    let (pkg_name, used_pkgs) = find_package_info(arena, unit);
21
22    if pkg_name.is_empty() {
23        return ast::Unit {
24            decls: unit.decls,
25            clauses: unit.clauses,
26        };
27    }
28
29    let defined_preds = find_defined_preds(unit);
30    let mut renamer = Renamer {
31        arena,
32        pkg_name,
33        used_pkgs,
34        defined_preds,
35        cache: FxHashMap::default(),
36    };
37
38    let mut new_decls = Vec::with_capacity(unit.decls.len());
39    for &decl in unit.decls {
40        if let Some(new_decl) = renamer.rewrite_decl(decl) {
41            new_decls.push(new_decl);
42        }
43    }
44
45    let mut new_clauses = Vec::with_capacity(unit.clauses.len());
46    for &clause in unit.clauses {
47        if let Some(new_clause) = renamer.rewrite_clause(clause) {
48            new_clauses.push(new_clause);
49        }
50    }
51
52    ast::Unit {
53        decls: arena.alloc_slice_copy(&new_decls),
54        clauses: arena.alloc_slice_copy(&new_clauses),
55    }
56}
57
58fn find_package_info<'a>(
59    arena: &'a ast::Arena,
60    unit: &'a ast::Unit<'a>,
61) -> (&'a str, FxHashSet<&'a str>) {
62    let mut pkg_name = "";
63    let mut used_pkgs = FxHashSet::default();
64
65    for &decl in unit.decls {
66        let pred_name = arena.predicate_name(decl.atom.sym).unwrap_or("");
67        if pred_name == "Package" {
68            if let Some(desc) = find_name_desc(arena, decl.descr) {
69                pkg_name = desc;
70            }
71        } else if pred_name == "Use"
72            && let Some(desc) = find_name_desc(arena, decl.descr)
73        {
74            used_pkgs.insert(desc);
75        }
76    }
77    (pkg_name, used_pkgs)
78}
79
80fn find_name_desc<'a>(arena: &'a ast::Arena, descr: &'a [&'a ast::Atom<'a>]) -> Option<&'a str> {
81    for &atom in descr {
82        if arena.predicate_name(atom.sym).unwrap_or("") == "name"
83            && let Some(&ast::BaseTerm::Const(ast::Const::String(s))) = atom.args.first().copied()
84        {
85            return Some(s);
86        }
87    }
88    None
89}
90
91fn find_defined_preds(unit: &ast::Unit) -> FxHashSet<ast::PredicateIndex> {
92    let mut defined = FxHashSet::default();
93    for &decl in unit.decls {
94        defined.insert(decl.atom.sym);
95    }
96    for &clause in unit.clauses {
97        defined.insert(clause.head.sym);
98    }
99    defined
100}
101
/// State carried while rewriting one unit's predicates into package-qualified form.
struct Renamer<'a> {
    /// Arena used to look up predicate names and to allocate rewritten nodes.
    arena: &'a ast::Arena,
    /// Package name that is prepended (followed by `.`) to local predicate names.
    pkg_name: &'a str,
    /// Package names gathered from the unit's `Use` declarations.
    used_pkgs: FxHashSet<&'a str>,
    /// Symbols that appear as a declaration atom or a clause head in the unit.
    defined_preds: FxHashSet<ast::PredicateIndex>,
    /// Memo of original symbol -> symbol to use in the rewritten unit.
    cache: FxHashMap<ast::PredicateIndex, ast::PredicateIndex>,
}
109
110impl<'a> Renamer<'a> {
111    fn rename_pred(&mut self, sym: ast::PredicateIndex) -> Option<ast::PredicateIndex> {
112        if let Some(&new_sym) = self.cache.get(&sym) {
113            return Some(new_sym);
114        }
115
116        let name = self.arena.predicate_name(sym)?;
117
118        // Don't rename Package and Use predicates themselves
119        if name == "Package" || name == "Use" {
120            self.cache.insert(sym, sym);
121            return Some(sym);
122        }
123
124        if self.defined_preds.contains(&sym) {
125            let new_name = format!("{}.{}", self.pkg_name, name);
126            let new_sym = self
127                .arena
128                .predicate_sym(&new_name, self.arena.predicate_arity(sym));
129            self.cache.insert(sym, new_sym);
130            return Some(new_sym);
131        }
132
133        // Check for cross-package references (e.g. `other.foo`)
134        if let Some(dot_idx) = name.rfind('.') {
135            let prefix = &name[..dot_idx];
136            if self.used_pkgs.contains(prefix) {
137                // It's a valid external reference, keep it as is.
138                self.cache.insert(sym, sym);
139                return Some(sym);
140            }
141        }
142
143        // Not defined locally, no dot. Must be a builtin or global?
144        self.cache.insert(sym, sym);
145        Some(sym)
146    }
147
148    fn rewrite_decl(&mut self, decl: &'a ast::Decl<'a>) -> Option<&'a ast::Decl<'a>> {
149        let pred_name = self.arena.predicate_name(decl.atom.sym).unwrap_or("");
150        if pred_name == "Package" || pred_name == "Use" {
151            // Remove Package and Use declarations from the rewritten unit
152            return None;
153        }
154
155        let new_atom = self.rewrite_atom(decl.atom)?;
156
157        // Rewrite bounds if necessary
158        let bounds = if let Some(bs) = decl.bounds {
159            let mut new_bounds = Vec::new();
160            for &b in bs {
161                let new_base_terms: Vec<&ast::BaseTerm> = b
162                    .base_terms
163                    .iter()
164                    .map(|&t| {
165                        // Check if it's a name constant that needs rewriting (e.g. type names like /foo -> /pkg.foo)
166                        if let ast::BaseTerm::Const(ast::Const::Name(name_idx)) = t {
167                            let name = self.arena.lookup_name(*name_idx).unwrap_or("");
168                            // Check if name corresponds to a defined predicate
169                            if let Some(pred_idx) = self.arena.lookup_predicate_sym(*name_idx)
170                                && self.defined_preds.contains(&pred_idx)
171                            {
172                                let new_name = format!("{}.{}", self.pkg_name, name);
173                                let new_name_idx = self.arena.intern(&new_name);
174                                return &*self
175                                    .arena
176                                    .alloc(ast::BaseTerm::Const(ast::Const::Name(new_name_idx)));
177                            }
178                        }
179                        t
180                    })
181                    .collect();
182                new_bounds.push(&*self.arena.alloc(ast::BoundDecl {
183                    base_terms: self.arena.alloc_slice_copy(&new_base_terms),
184                }));
185            }
186            Some(self.arena.alloc_slice_copy(&new_bounds) as &'a [&'a ast::BoundDecl<'a>])
187        } else {
188            None
189        };
190
191        Some(self.arena.alloc(ast::Decl {
192            atom: new_atom,
193            descr: decl.descr, // Descriptions usually don't contain predicates to rename?
194            bounds,
195            constraints: decl.constraints,
196            is_temporal: decl.is_temporal,
197        }))
198    }
199
200    fn rewrite_clause(&mut self, clause: &'a ast::Clause<'a>) -> Option<&'a ast::Clause<'a>> {
201        let head = self.rewrite_atom(clause.head)?;
202        let mut premises = Vec::new();
203        for &premise in clause.premises {
204            match premise {
205                ast::Term::Atom(a) => {
206                    premises.push(&*self.arena.alloc(ast::Term::Atom(self.rewrite_atom(a)?)));
207                }
208                ast::Term::NegAtom(a) => {
209                    premises.push(&*self.arena.alloc(ast::Term::NegAtom(self.rewrite_atom(a)?)));
210                }
211                ast::Term::TemporalAtom(a, interval) => {
212                    premises.push(&*self.arena.alloc(ast::Term::TemporalAtom(self.rewrite_atom(a)?, *interval)));
213                }
214                _ => premises.push(premise),
215            }
216        }
217
218        Some(self.arena.alloc(ast::Clause {
219            head,
220            head_time: clause.head_time,
221            premises: self.arena.alloc_slice_copy(&premises),
222            transform: clause.transform,
223        }))
224    }
225
226    fn rewrite_atom(&mut self, atom: &'a ast::Atom<'a>) -> Option<&'a ast::Atom<'a>> {
227        let new_sym = self.rename_pred(atom.sym)?;
228        Some(self.arena.atom(new_sym, atom.args))
229    }
230}
231
#[cfg(test)]
mod tests {
    use super::*;
    use googletest::prelude::*;

    /// Builds a zero-arity declaration for predicate `pred` whose description
    /// is a single `name("<name>")` atom -- the shape `find_package_info`
    /// reads for `Package` and `Use` declarations.
    fn make_named_decl<'a>(arena: &'a ast::Arena, pred: &str, name: &str) -> &'a ast::Decl<'a> {
        let pred_sym = arena.predicate_sym(pred, Some(0));
        let name_sym = arena.predicate_sym("name", Some(1));
        let name_term = arena.alloc(ast::BaseTerm::Const(ast::Const::String(
            arena.alloc_str(name),
        )));
        arena.alloc(ast::Decl {
            atom: arena.atom(pred_sym, &[]),
            descr: arena.alloc_slice_copy(&[arena.atom(name_sym, &[name_term])]),
            bounds: None,
            constraints: None,
            is_temporal: false,
        })
    }

    /// Builds a `Package <name>!` declaration.
    fn make_pkg_decl<'a>(arena: &'a ast::Arena, name: &str) -> &'a ast::Decl<'a> {
        make_named_decl(arena, "Package", name)
    }

    #[test]
    fn test_rename_simple() {
        let arena = ast::Arena::new_with_global_interner();

        // Package pkg!
        // foo(X) :- bar(X).
        // bar(1).

        let pkg_decl = make_pkg_decl(&arena, "pkg");

        let foo_sym = arena.predicate_sym("foo", Some(1));
        let bar_sym = arena.predicate_sym("bar", Some(1));
        let var_x = arena.variable("X");
        let const_1 = arena.const_(ast::Const::Number(1));

        let clause1 = arena.alloc(ast::Clause {
            head: arena.atom(foo_sym, &[var_x]),
            head_time: None,
            premises: arena
                .alloc_slice_copy(&[arena.alloc(ast::Term::Atom(arena.atom(bar_sym, &[var_x])))]),
            transform: &[],
        });

        let clause2 = arena.alloc(ast::Clause {
            head: arena.atom(bar_sym, &[const_1]),
            head_time: None,
            premises: &[],
            transform: &[],
        });

        let unit = ast::Unit {
            decls: arena.alloc_slice_copy(&[pkg_decl]),
            clauses: arena.alloc_slice_copy(&[clause1, clause2]),
        };

        let new_unit = rewrite_unit(&arena, &unit);

        // Package decl should be removed
        assert_that!(new_unit.decls.len(), eq(0));
        assert_that!(new_unit.clauses.len(), eq(2));

        let c1 = new_unit.clauses[0];
        let c2 = new_unit.clauses[1];

        // Check head of c1: foo -> pkg.foo
        let head_name = arena.predicate_name(c1.head.sym).unwrap();
        assert_that!(head_name, eq("pkg.foo"));

        // Check premise of c1: bar -> pkg.bar
        if let ast::Term::Atom(a) = c1.premises[0] {
            let p_name = arena.predicate_name(a.sym).unwrap();
            assert_that!(p_name, eq("pkg.bar"));
        } else {
            panic!("Expected Atom premise");
        }

        // Check head of c2: bar -> pkg.bar
        let head_name_2 = arena.predicate_name(c2.head.sym).unwrap();
        assert_that!(head_name_2, eq("pkg.bar"));
    }

    #[test]
    fn test_rename_with_use() {
        let arena = ast::Arena::new_with_global_interner();

        // Package pkg!
        // Use other!
        // foo(X) :- other.bar(X).

        let pkg_decl = make_pkg_decl(&arena, "pkg");
        let use_decl = make_named_decl(&arena, "Use", "other");

        let foo_sym = arena.predicate_sym("foo", Some(1));
        let other_bar_sym = arena.predicate_sym("other.bar", Some(1));
        let var_x = arena.variable("X");

        let clause1 = arena.alloc(ast::Clause {
            head: arena.atom(foo_sym, &[var_x]),
            head_time: None,
            premises: arena.alloc_slice_copy(&[
                arena.alloc(ast::Term::Atom(arena.atom(other_bar_sym, &[var_x])))
            ]),
            transform: &[],
        });

        let unit = ast::Unit {
            decls: arena.alloc_slice_copy(&[pkg_decl, use_decl]),
            clauses: arena.alloc_slice_copy(&[clause1]),
        };

        let new_unit = rewrite_unit(&arena, &unit);

        assert_that!(new_unit.clauses.len(), eq(1));
        let c1 = new_unit.clauses[0];

        // foo -> pkg.foo
        let head_name = arena.predicate_name(c1.head.sym).unwrap();
        assert_that!(head_name, eq("pkg.foo"));

        // other.bar -> other.bar (unchanged because 'other' is Used)
        if let ast::Term::Atom(a) = c1.premises[0] {
            let p_name = arena.predicate_name(a.sym).unwrap();
            assert_that!(p_name, eq("other.bar"));
        } else {
            panic!("Expected Atom premise");
        }
    }

    #[test]
    fn test_go_case_no_package() {
        // "no package name, clauses are not rewritten"
        let arena = ast::Arena::new_with_global_interner();
        let clause = arena.alloc(ast::Clause {
            head: arena.atom(arena.predicate_sym("clause_defined_here", None), &[]),
            head_time: None,
            premises: arena.alloc_slice_copy(&[arena.alloc(ast::Term::Atom(
                arena.atom(arena.predicate_sym("other_clause", None), &[]),
            ))]),
            transform: &[],
        });
        let unit = ast::Unit {
            decls: &[],
            clauses: arena.alloc_slice_copy(&[clause]),
        };
        let new_unit = rewrite_unit(&arena, &unit);
        let head = arena.predicate_name(new_unit.clauses[0].head.sym).unwrap();
        assert_that!(head, eq("clause_defined_here"));
        // Fail loudly if the premise is not an atom instead of silently
        // skipping the assertion.
        if let ast::Term::Atom(a) = new_unit.clauses[0].premises[0] {
            let p = arena.predicate_name(a.sym).unwrap();
            assert_that!(p, eq("other_clause"));
        } else {
            panic!("Expected Atom premise");
        }
    }

    #[test]
    fn test_go_case_external_refs() {
        // "references to predicates outside the package are left as-is"
        let arena = ast::Arena::new_with_global_interner();
        // Package foo.bar!
        let pkg_decl = make_pkg_decl(&arena, "foo.bar");

        // clause_defined_here :- other_clause.
        // (other_clause is NOT defined locally)
        let clause = arena.alloc(ast::Clause {
            head: arena.atom(arena.predicate_sym("clause_defined_here", None), &[]),
            head_time: None,
            premises: arena.alloc_slice_copy(&[arena.alloc(ast::Term::Atom(
                arena.atom(arena.predicate_sym("other_clause", None), &[]),
            ))]),
            transform: &[],
        });

        let unit = ast::Unit {
            decls: arena.alloc_slice_copy(&[pkg_decl]),
            clauses: arena.alloc_slice_copy(&[clause]),
        };
        let new_unit = rewrite_unit(&arena, &unit);

        // head -> foo.bar.clause_defined_here
        let head = arena.predicate_name(new_unit.clauses[0].head.sym).unwrap();
        assert_that!(head, eq("foo.bar.clause_defined_here"));

        // premise -> other_clause (unchanged)
        if let ast::Term::Atom(a) = new_unit.clauses[0].premises[0] {
            let p = arena.predicate_name(a.sym).unwrap();
            assert_that!(p, eq("other_clause"));
        } else {
            panic!("Expected Atom premise");
        }
    }

    #[test]
    fn test_go_case_rewritten_local() {
        // "clauses defined in this package are rewritten"
        let arena = ast::Arena::new_with_global_interner();
        let pkg_decl = make_pkg_decl(&arena, "foo.bar");

        let defined_sym = arena.predicate_sym("clause_defined_here", None);
        let other_sym = arena.predicate_sym("other_clause", None);

        // other_clause().
        let clause1 = arena.alloc(ast::Clause {
            head: arena.atom(other_sym, &[]),
            head_time: None,
            premises: &[],
            transform: &[],
        });

        // clause_defined_here() :- other_clause().
        let clause2 = arena.alloc(ast::Clause {
            head: arena.atom(defined_sym, &[]),
            head_time: None,
            premises: arena
                .alloc_slice_copy(&[arena.alloc(ast::Term::Atom(arena.atom(other_sym, &[])))]),
            transform: &[],
        });

        let unit = ast::Unit {
            decls: arena.alloc_slice_copy(&[pkg_decl]),
            clauses: arena.alloc_slice_copy(&[clause1, clause2]),
        };
        let new_unit = rewrite_unit(&arena, &unit);

        // clause1 head: foo.bar.other_clause
        let h1 = arena.predicate_name(new_unit.clauses[0].head.sym).unwrap();
        assert_that!(h1, eq("foo.bar.other_clause"));

        // clause2 head: foo.bar.clause_defined_here
        let h2 = arena.predicate_name(new_unit.clauses[1].head.sym).unwrap();
        assert_that!(h2, eq("foo.bar.clause_defined_here"));

        // clause2 premise: foo.bar.other_clause
        if let ast::Term::Atom(a) = new_unit.clauses[1].premises[0] {
            let p = arena.predicate_name(a.sym).unwrap();
            assert_that!(p, eq("foo.bar.other_clause"));
        } else {
            panic!("Expected Atom premise");
        }
    }

    #[test]
    fn test_go_case_negation() {
        // "clause with a negation is rewritten"
        let arena = ast::Arena::new_with_global_interner();
        let pkg_decl = make_pkg_decl(&arena, "foo.bar");

        let defined_sym = arena.predicate_sym("clause_defined_here", None);
        let other_sym = arena.predicate_sym("other_clause", None);

        // other_clause(). (Needs to be defined to trigger renaming)
        let clause1 = arena.alloc(ast::Clause {
            head: arena.atom(other_sym, &[]),
            head_time: None,
            premises: &[],
            transform: &[],
        });

        // clause_defined_here() :- !other_clause().
        let clause2 = arena.alloc(ast::Clause {
            head: arena.atom(defined_sym, &[]),
            head_time: None,
            premises: arena
                .alloc_slice_copy(&[arena.alloc(ast::Term::NegAtom(arena.atom(other_sym, &[])))]),
            transform: &[],
        });

        let unit = ast::Unit {
            decls: arena.alloc_slice_copy(&[pkg_decl]),
            clauses: arena.alloc_slice_copy(&[clause1, clause2]),
        };
        let new_unit = rewrite_unit(&arena, &unit);

        if let ast::Term::NegAtom(a) = new_unit.clauses[1].premises[0] {
            let p = arena.predicate_name(a.sym).unwrap();
            assert_that!(p, eq("foo.bar.other_clause"));
        } else {
            panic!("Expected NegAtom");
        }
    }

    #[test]
    fn test_go_case_decl_only() {
        // "clauses are also rewritten if the decl was declared in this package"
        let arena = ast::Arena::new_with_global_interner();
        let pkg_decl = make_pkg_decl(&arena, "foo.bar");

        let clause_sym = arena.predicate_sym("clause", None);
        let decl_sym = arena.predicate_sym("from_decl", None);

        // Decl from_decl.
        let decl = arena.alloc(ast::Decl {
            atom: arena.atom(decl_sym, &[]),
            descr: &[],
            bounds: None,
            constraints: None,
            is_temporal: false,
        });

        // clause() :- from_decl().
        let clause = arena.alloc(ast::Clause {
            head: arena.atom(clause_sym, &[]),
            head_time: None,
            premises: arena
                .alloc_slice_copy(&[arena.alloc(ast::Term::Atom(arena.atom(decl_sym, &[])))]),
            transform: &[],
        });

        let unit = ast::Unit {
            decls: arena.alloc_slice_copy(&[pkg_decl, decl]),
            clauses: arena.alloc_slice_copy(&[clause]),
        };
        let new_unit = rewrite_unit(&arena, &unit);

        // clause -> foo.bar.clause
        let h = arena.predicate_name(new_unit.clauses[0].head.sym).unwrap();
        assert_that!(h, eq("foo.bar.clause"));

        // from_decl -> foo.bar.from_decl
        if let ast::Term::Atom(a) = new_unit.clauses[0].premises[0] {
            let p = arena.predicate_name(a.sym).unwrap();
            assert_that!(p, eq("foo.bar.from_decl"));
        } else {
            panic!("Expected Atom premise");
        }

        // Decl atom also rewritten
        let d_name = arena.predicate_name(new_unit.decls[0].atom.sym).unwrap();
        assert_that!(d_name, eq("foo.bar.from_decl"));
    }
}