Skip to main content

mangle_analysis/
lowering.rs

1// Copyright 2025 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use fxhash::FxHashMap;
16use mangle_ast as ast;
17use mangle_ir::{Inst, InstId, Ir};
18
19pub struct LoweringContext<'a> {
20    arena: &'a ast::Arena,
21    ir: Ir,
22    // Scope-specific maps
23    vars: FxHashMap<ast::VariableIndex, InstId>,
24}
25
26impl<'a> LoweringContext<'a> {
27    pub fn new(arena: &'a ast::Arena) -> Self {
28        Self {
29            arena,
30            ir: Ir::new(),
31            vars: FxHashMap::default(),
32        }
33    }
34
35    pub fn lower_unit(mut self, unit: &ast::Unit) -> Ir {
36        for decl in unit.decls {
37            self.lower_decl(decl);
38        }
39        for clause in unit.clauses {
40            self.lower_clause(clause);
41        }
42        self.ir
43    }
44
45    fn lower_decl(&mut self, decl: &ast::Decl) -> InstId {
46        self.vars.clear();
47
48        // Track temporal predicates from declarations
49        if decl.is_temporal {
50            let pred_name = self
51                .arena
52                .predicate_name(decl.atom.sym)
53                .unwrap_or("unknown_pred")
54                .to_string();
55            let name_id = self.ir.intern_name(pred_name);
56            self.ir.temporal_predicates.insert(name_id);
57        }
58
59        let atom = self.lower_atom(decl.atom);
60        let descr: Vec<InstId> = decl.descr.iter().map(|a| self.lower_atom(a)).collect();
61        let bounds: Vec<InstId> = if let Some(bs) = decl.bounds {
62            bs.iter().map(|b| self.lower_bound_decl(b)).collect()
63        } else {
64            Vec::new()
65        };
66        let constraints = decl.constraints.map(|c| self.lower_constraints(c));
67
68        self.ir.add_inst(Inst::Decl {
69            atom,
70            descr,
71            bounds,
72            constraints,
73        })
74    }
75
76    fn lower_clause(&mut self, clause: &ast::Clause) -> InstId {
77        self.vars.clear();
78
79        let head = if let Some(interval) = &clause.head_time {
80            // Temporal head: append synthetic interval columns
81            self.lower_temporal_atom(clause.head, interval)
82        } else {
83            self.lower_atom(clause.head)
84        };
85        let premises: Vec<InstId> = clause.premises.iter().map(|t| self.lower_term(t)).collect();
86        let transform: Vec<InstId> = clause
87            .transform
88            .iter()
89            .map(|t| self.lower_transform(t))
90            .collect();
91
92        self.ir.add_inst(Inst::Rule {
93            head,
94            premises,
95            transform,
96        })
97    }
98
99    /// Lower an atom with temporal interval as 2 extra columns.
100    fn lower_temporal_atom(&mut self, atom: &ast::Atom, interval: &ast::Interval) -> InstId {
101        let predicate_name = self
102            .arena
103            .predicate_name(atom.sym)
104            .unwrap_or("unknown_pred")
105            .to_string();
106        let predicate = self.ir.intern_name(predicate_name);
107        // Track this predicate as temporal
108        self.ir.temporal_predicates.insert(predicate);
109        let mut args: Vec<InstId> = atom
110            .args
111            .iter()
112            .map(|arg| self.lower_base_term(arg))
113            .collect();
114        args.push(self.lower_temporal_bound(&interval.start));
115        args.push(self.lower_temporal_bound(&interval.end));
116        self.ir.add_inst(Inst::Atom { predicate, args })
117    }
118
119    fn lower_atom(&mut self, atom: &ast::Atom) -> InstId {
120        let predicate_name = self
121            .arena
122            .predicate_name(atom.sym)
123            .unwrap_or("unknown_pred")
124            .to_string();
125        let predicate = self.ir.intern_name(predicate_name);
126        let args: Vec<InstId> = atom
127            .args
128            .iter()
129            .map(|arg| self.lower_base_term(arg))
130            .collect();
131        self.ir.add_inst(Inst::Atom { predicate, args })
132    }
133
134    fn lower_term(&mut self, term: &ast::Term) -> InstId {
135        match term {
136            ast::Term::Atom(a) => self.lower_atom(a),
137            ast::Term::NegAtom(a) => {
138                let atom = self.lower_atom(a);
139                self.ir.add_inst(Inst::NegAtom(atom))
140            }
141            ast::Term::Eq(l, r) => {
142                let left = self.lower_base_term(l);
143                let right = self.lower_base_term(r);
144                self.ir.add_inst(Inst::Eq(left, right))
145            }
146            ast::Term::Ineq(l, r) => {
147                let left = self.lower_base_term(l);
148                let right = self.lower_base_term(r);
149                self.ir.add_inst(Inst::Ineq(left, right))
150            }
151            ast::Term::TemporalAtom(a, interval) => self.lower_temporal_atom(a, interval),
152        }
153    }
154
155    fn lower_base_term(&mut self, term: &ast::BaseTerm) -> InstId {
156        match term {
157            ast::BaseTerm::Const(c) => self.lower_const(c),
158            ast::BaseTerm::Variable(v) => {
159                if let Some(id) = self.vars.get(v) {
160                    *id
161                } else {
162                    let name_str = if v.0 == 0 {
163                        "_".to_string()
164                    } else {
165                        self.arena
166                            .lookup_name(v.0)
167                            .unwrap_or("unknown_var")
168                            .to_string()
169                    };
170                    let name = self.ir.intern_name(name_str);
171                    let id = self.ir.add_inst(Inst::Var(name));
172                    // Don't cache wildcard?
173                    if v.0 != 0 {
174                        self.vars.insert(*v, id);
175                    }
176                    id
177                }
178            }
179            ast::BaseTerm::ApplyFn(f, args) => {
180                let function_str = self
181                    .arena
182                    .function_name(*f)
183                    .unwrap_or("unknown_fn")
184                    .to_string();
185                let function = self.ir.intern_name(function_str);
186                let args = args.iter().map(|a| self.lower_base_term(a)).collect();
187                self.ir.add_inst(Inst::ApplyFn { function, args })
188            }
189        }
190    }
191
192    fn lower_const(&mut self, c: &ast::Const) -> InstId {
193        match c {
194            ast::Const::Name(n) => {
195                let name_str = self
196                    .arena
197                    .lookup_name(*n)
198                    .unwrap_or("unknown_name")
199                    .to_string();
200                let name = self.ir.intern_name(name_str);
201                self.ir.add_inst(Inst::Name(name))
202            }
203            ast::Const::Bool(b) => self.ir.add_inst(Inst::Bool(*b)),
204            ast::Const::Number(n) => self.ir.add_inst(Inst::Number(*n)),
205            ast::Const::Float(f) => self.ir.add_inst(Inst::Float(*f)),
206            ast::Const::Time(t) => self.ir.add_inst(Inst::Time(*t)),
207            ast::Const::Duration(d) => self.ir.add_inst(Inst::Duration(*d)),
208            ast::Const::String(s) => {
209                let id = self.ir.intern_string(*s);
210                self.ir.add_inst(Inst::String(id))
211            }
212            ast::Const::Bytes(b) => self.ir.add_inst(Inst::Bytes(b.to_vec())),
213            ast::Const::List(l) => {
214                let args = l.iter().map(|c| self.lower_const(c)).collect();
215                self.ir.add_inst(Inst::List(args))
216            }
217            ast::Const::Map { keys, values } => {
218                let keys = keys.iter().map(|c| self.lower_const(c)).collect();
219                let values = values.iter().map(|c| self.lower_const(c)).collect();
220                self.ir.add_inst(Inst::Map { keys, values })
221            }
222            ast::Const::Struct { fields, values } => {
223                let fields = fields
224                    .iter()
225                    .map(|s| self.ir.intern_name(s.to_string()))
226                    .collect();
227                let values = values.iter().map(|c| self.lower_const(c)).collect();
228                self.ir.add_inst(Inst::Struct { fields, values })
229            }
230        }
231    }
232
233    fn lower_temporal_bound(&mut self, bound: &ast::TemporalBound) -> InstId {
234        match bound {
235            ast::TemporalBound::Timestamp(nanos) => self.ir.add_inst(Inst::Time(*nanos)),
236            ast::TemporalBound::Variable(var_idx) => {
237                if let Some(id) = self.vars.get(var_idx) {
238                    *id
239                } else {
240                    let name_str = self
241                        .arena
242                        .lookup_name(var_idx.0)
243                        .unwrap_or("unknown_var")
244                        .to_string();
245                    let name = self.ir.intern_name(name_str);
246                    let id = self.ir.add_inst(Inst::Var(name));
247                    if var_idx.0 != 0 {
248                        self.vars.insert(*var_idx, id);
249                    }
250                    id
251                }
252            }
253            ast::TemporalBound::NegInf => self.ir.add_inst(Inst::Time(i64::MIN)),
254            ast::TemporalBound::PosInf => self.ir.add_inst(Inst::Time(i64::MAX)),
255        }
256    }
257
258    fn lower_transform(&mut self, t: &ast::TransformStmt) -> InstId {
259        let var = t.var.map(|s| self.ir.intern_name(s.to_string()));
260        let app = self.lower_base_term(t.app);
261        self.ir.add_inst(Inst::Transform { var, app })
262    }
263
264    fn lower_bound_decl(&mut self, b: &ast::BoundDecl) -> InstId {
265        let base_terms = b
266            .base_terms
267            .iter()
268            .map(|t| self.lower_base_term(t))
269            .collect();
270        self.ir.add_inst(Inst::BoundDecl { base_terms })
271    }
272
273    fn lower_constraints(&mut self, c: &ast::Constraints) -> InstId {
274        let consequences = c.consequences.iter().map(|a| self.lower_atom(a)).collect();
275        let alternatives = c
276            .alternatives
277            .iter()
278            .map(|alt| alt.iter().map(|a| self.lower_atom(a)).collect())
279            .collect();
280        self.ir.add_inst(Inst::Constraints {
281            consequences,
282            alternatives,
283        })
284    }
285}