// mangle_analysis/lowering.rs

// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15use fxhash::FxHashMap;
16use mangle_ast as ast;
17use mangle_ir::{Inst, InstId, Ir};
18
19pub struct LoweringContext<'a> {
20    arena: &'a ast::Arena,
21    ir: Ir,
22    // Scope-specific maps
23    vars: FxHashMap<ast::VariableIndex, InstId>,
24}
25
26impl<'a> LoweringContext<'a> {
27    pub fn new(arena: &'a ast::Arena) -> Self {
28        Self {
29            arena,
30            ir: Ir::new(),
31            vars: FxHashMap::default(),
32        }
33    }
34
35    pub fn lower_unit(mut self, unit: &ast::Unit) -> Ir {
36        for decl in unit.decls {
37            self.lower_decl(decl);
38        }
39        for clause in unit.clauses {
40            self.lower_clause(clause);
41        }
42        self.ir
43    }
44
45    fn lower_decl(&mut self, decl: &ast::Decl) -> InstId {
46        self.vars.clear();
47
48        let atom = self.lower_atom(decl.atom);
49        let descr: Vec<InstId> = decl.descr.iter().map(|a| self.lower_atom(a)).collect();
50        let bounds: Vec<InstId> = if let Some(bs) = decl.bounds {
51            bs.iter().map(|b| self.lower_bound_decl(b)).collect()
52        } else {
53            Vec::new()
54        };
55        let constraints = decl.constraints.map(|c| self.lower_constraints(c));
56
57        self.ir.add_inst(Inst::Decl {
58            atom,
59            descr,
60            bounds,
61            constraints,
62        })
63    }
64
65    fn lower_clause(&mut self, clause: &ast::Clause) -> InstId {
66        self.vars.clear();
67
68        let head = self.lower_atom(clause.head);
69        let premises: Vec<InstId> = clause.premises.iter().map(|t| self.lower_term(t)).collect();
70        let transform: Vec<InstId> = clause
71            .transform
72            .iter()
73            .map(|t| self.lower_transform(t))
74            .collect();
75
76        self.ir.add_inst(Inst::Rule {
77            head,
78            premises,
79            transform,
80        })
81    }
82
83    fn lower_atom(&mut self, atom: &ast::Atom) -> InstId {
84        let predicate_name = self
85            .arena
86            .predicate_name(atom.sym)
87            .unwrap_or("unknown_pred")
88            .to_string();
89        let predicate = self.ir.intern_name(predicate_name);
90        let args: Vec<InstId> = atom
91            .args
92            .iter()
93            .map(|arg| self.lower_base_term(arg))
94            .collect();
95        self.ir.add_inst(Inst::Atom { predicate, args })
96    }
97
98    fn lower_term(&mut self, term: &ast::Term) -> InstId {
99        match term {
100            ast::Term::Atom(a) => self.lower_atom(a),
101            ast::Term::NegAtom(a) => {
102                let atom = self.lower_atom(a);
103                self.ir.add_inst(Inst::NegAtom(atom))
104            }
105            ast::Term::Eq(l, r) => {
106                let left = self.lower_base_term(l);
107                let right = self.lower_base_term(r);
108                self.ir.add_inst(Inst::Eq(left, right))
109            }
110            ast::Term::Ineq(l, r) => {
111                let left = self.lower_base_term(l);
112                let right = self.lower_base_term(r);
113                self.ir.add_inst(Inst::Ineq(left, right))
114            }
115        }
116    }
117
118    fn lower_base_term(&mut self, term: &ast::BaseTerm) -> InstId {
119        match term {
120            ast::BaseTerm::Const(c) => self.lower_const(c),
121            ast::BaseTerm::Variable(v) => {
122                if let Some(id) = self.vars.get(v) {
123                    *id
124                } else {
125                    let name_str = if v.0 == 0 {
126                        "_".to_string()
127                    } else {
128                        self.arena
129                            .lookup_name(v.0)
130                            .unwrap_or("unknown_var")
131                            .to_string()
132                    };
133                    let name = self.ir.intern_name(name_str);
134                    let id = self.ir.add_inst(Inst::Var(name));
135                    // Don't cache wildcard?
136                    if v.0 != 0 {
137                        self.vars.insert(*v, id);
138                    }
139                    id
140                }
141            }
142            ast::BaseTerm::ApplyFn(f, args) => {
143                let function_str = self
144                    .arena
145                    .function_name(*f)
146                    .unwrap_or("unknown_fn")
147                    .to_string();
148                let function = self.ir.intern_name(function_str);
149                let args = args.iter().map(|a| self.lower_base_term(a)).collect();
150                self.ir.add_inst(Inst::ApplyFn { function, args })
151            }
152        }
153    }
154
155    fn lower_const(&mut self, c: &ast::Const) -> InstId {
156        match c {
157            ast::Const::Name(n) => {
158                let name_str = self
159                    .arena
160                    .lookup_name(*n)
161                    .unwrap_or("unknown_name")
162                    .to_string();
163                let name = self.ir.intern_name(name_str);
164                self.ir.add_inst(Inst::Name(name))
165            }
166            ast::Const::Bool(b) => self.ir.add_inst(Inst::Bool(*b)),
167            ast::Const::Number(n) => self.ir.add_inst(Inst::Number(*n)),
168            ast::Const::Float(f) => self.ir.add_inst(Inst::Float(*f)),
169            ast::Const::String(s) => {
170                let id = self.ir.intern_string(*s);
171                self.ir.add_inst(Inst::String(id))
172            }
173            ast::Const::Bytes(b) => self.ir.add_inst(Inst::Bytes(b.to_vec())),
174            ast::Const::List(l) => {
175                let args = l.iter().map(|c| self.lower_const(c)).collect();
176                self.ir.add_inst(Inst::List(args))
177            }
178            ast::Const::Map { keys, values } => {
179                let keys = keys.iter().map(|c| self.lower_const(c)).collect();
180                let values = values.iter().map(|c| self.lower_const(c)).collect();
181                self.ir.add_inst(Inst::Map { keys, values })
182            }
183            ast::Const::Struct { fields, values } => {
184                let fields = fields
185                    .iter()
186                    .map(|s| self.ir.intern_name(s.to_string()))
187                    .collect();
188                let values = values.iter().map(|c| self.lower_const(c)).collect();
189                self.ir.add_inst(Inst::Struct { fields, values })
190            }
191        }
192    }
193
194    fn lower_transform(&mut self, t: &ast::TransformStmt) -> InstId {
195        let var = t.var.map(|s| self.ir.intern_name(s.to_string()));
196        let app = self.lower_base_term(t.app);
197        self.ir.add_inst(Inst::Transform { var, app })
198    }
199
200    fn lower_bound_decl(&mut self, b: &ast::BoundDecl) -> InstId {
201        let base_terms = b
202            .base_terms
203            .iter()
204            .map(|t| self.lower_base_term(t))
205            .collect();
206        self.ir.add_inst(Inst::BoundDecl { base_terms })
207    }
208
209    fn lower_constraints(&mut self, c: &ast::Constraints) -> InstId {
210        let consequences = c.consequences.iter().map(|a| self.lower_atom(a)).collect();
211        let alternatives = c
212            .alternatives
213            .iter()
214            .map(|alt| alt.iter().map(|a| self.lower_atom(a)).collect())
215            .collect();
216        self.ir.add_inst(Inst::Constraints {
217            consequences,
218            alternatives,
219        })
220    }
221}