prune_lang/logic/transform.rs

1use crate::syntax::ast;
2
3use super::optimize;
4use super::*;
5
6pub struct Translater {
7    vars: Vec<(Ident, TypeId)>,
8}
9
10impl Translater {
11    fn new() -> Translater {
12        Translater { vars: Vec::new() }
13    }
14
15    fn fresh_var(&mut self) -> Ident {
16        let var = Ident::fresh(&"x");
17        let typ = TypeId::Var(Ident::fresh(&"a"));
18        self.vars.push((var, typ));
19        var
20    }
21
22    fn translate_func(&mut self, func: &ast::FuncDecl) -> PredDecl {
23        self.vars = Vec::new();
24        let (term, goal) = self.translate_expr(&func.body);
25        let name = func.name.ident;
26        let polys = func.polys.iter().map(|poly| poly.ident).collect();
27        let mut pars: Vec<(Ident, TypeId)> = func
28            .pars
29            .iter()
30            .map(|(var, typ)| (var.ident, translate_type(typ)))
31            .collect();
32        let res = Ident::fresh(&"res");
33        pars.push((res, translate_type(&func.res)));
34        let goal = Goal::And(vec![Goal::Eq(Term::Var(res), term), goal]);
35        PredDecl {
36            name,
37            polys,
38            pars,
39            vars: self.vars.clone(),
40            goal: optimize::goal_optimize(goal),
41        }
42    }
43
44    fn translate_expr(&mut self, expr: &ast::Expr) -> (TermId, Goal) {
45        match expr {
46            ast::Expr::Lit { lit, span: _ } => (Term::Lit(*lit), Goal::Lit(true)),
47            ast::Expr::Var { var, span: _ } => (Term::Var(var.ident), Goal::Lit(true)),
48            ast::Expr::Prim {
49                prim,
50                args,
51                span: _,
52            } => {
53                let x = self.fresh_var();
54                let (mut terms, mut goals): (Vec<TermId>, Vec<Goal>) =
55                    args.iter().map(|arg| self.translate_expr(arg)).unzip();
56                terms.push(Term::Var(x));
57                let terms = terms
58                    .into_iter()
59                    .map(|term| term.to_atom().unwrap())
60                    .collect();
61                goals.push(Goal::Prim(*prim, terms));
62                (Term::Var(x), Goal::And(goals))
63            }
64            ast::Expr::Cons {
65                cons,
66                flds,
67                span: _,
68            } => {
69                let (flds, goals): (Vec<TermId>, Vec<Goal>) =
70                    flds.iter().map(|fld| self.translate_expr(fld)).unzip();
71                (
72                    Term::Cons(OptCons::Some(cons.ident), flds),
73                    Goal::And(goals),
74                )
75            }
76            ast::Expr::Tuple { flds, span: _ } => {
77                let (flds, goals): (Vec<TermId>, Vec<Goal>) =
78                    flds.iter().map(|fld| self.translate_expr(fld)).unzip();
79                (Term::Cons(OptCons::None, flds), Goal::And(goals))
80            }
81            ast::Expr::Match {
82                expr,
83                brchs,
84                span: _,
85            } => {
86                let x = self.fresh_var();
87                let (term0, goal0) = self.translate_expr(expr);
88                let mut goals = Vec::new();
89                for (patn, expr) in brchs {
90                    let patn_term = self.translate_patn(patn);
91                    let (term1, goal1) = self.translate_expr(expr);
92                    goals.push(Goal::And(vec![
93                        Goal::Eq(term0.clone(), patn_term),
94                        goal1,
95                        Goal::Eq(Term::Var(x), term1),
96                    ]));
97                }
98                (Term::Var(x), Goal::And(vec![goal0, Goal::Or(goals)]))
99            }
100            ast::Expr::Let {
101                patn,
102                expr,
103                cont,
104                span: _,
105            } => {
106                let (term0, goal0) = self.translate_expr(expr);
107                let patn_term = self.translate_patn(patn);
108                let (term1, goal1) = self.translate_expr(cont);
109                (
110                    term1,
111                    Goal::And(vec![goal0, Goal::Eq(term0, patn_term), goal1]),
112                )
113            }
114            ast::Expr::App {
115                func,
116                args,
117                span: _,
118            } => {
119                let x = self.fresh_var();
120                let (mut terms, mut goals): (Vec<TermId>, Vec<Goal>) =
121                    args.iter().map(|arg| self.translate_expr(arg)).unzip();
122                terms.push(Term::Var(x));
123                goals.push(Goal::Call(func.ident, Vec::new(), terms));
124                (Term::Var(x), Goal::And(goals))
125            }
126            ast::Expr::Ifte {
127                cond,
128                then,
129                els,
130                span: _,
131            } => {
132                let x = self.fresh_var();
133                let (term0, goal0) = self.translate_expr(cond);
134                let (term1, goal1) = self.translate_expr(then);
135                let (term2, goal2) = self.translate_expr(els);
136                match term0 {
137                    Term::Var(var) => {
138                        let goal = Goal::And(vec![
139                            goal0,
140                            Goal::Or(vec![
141                                Goal::And(vec![
142                                    Goal::Eq(Term::Var(var), Term::Lit(LitVal::Bool(true))),
143                                    goal1,
144                                    Goal::Eq(Term::Var(x), term1),
145                                ]),
146                                Goal::And(vec![
147                                    Goal::Eq(Term::Var(var), Term::Lit(LitVal::Bool(false))),
148                                    goal2,
149                                    Goal::Eq(Term::Var(x), term2),
150                                ]),
151                            ]),
152                        ]);
153                        (Term::Var(x), goal)
154                    }
155                    Term::Lit(LitVal::Bool(true)) => (term1, Goal::And(vec![goal0, goal1])),
156                    Term::Lit(LitVal::Bool(false)) => (term2, Goal::And(vec![goal0, goal2])),
157                    _ => {
158                        unreachable!();
159                    }
160                }
161            }
162            ast::Expr::Cond { brchs, span: _ } => {
163                let x = self.fresh_var();
164                let mut goals = Vec::new();
165                for (cond, body) in brchs {
166                    let (term0, goal0) = self.translate_expr(cond);
167                    let (term1, goal1) = self.translate_expr(body);
168                    match term0 {
169                        Term::Var(var) => {
170                            let goal = Goal::And(vec![
171                                goal0,
172                                Goal::Eq(Term::Var(var), Term::Lit(LitVal::Bool(true))),
173                                goal1,
174                                Goal::Eq(Term::Var(x), term1),
175                            ]);
176                            goals.push(goal);
177                        }
178                        Term::Lit(LitVal::Bool(true)) => {
179                            let goal = Goal::And(vec![goal0, goal1, Goal::Eq(Term::Var(x), term1)]);
180                            goals.push(goal);
181                        }
182                        Term::Lit(LitVal::Bool(false)) => {}
183                        _ => {
184                            unreachable!();
185                        }
186                    }
187                }
188                (Term::Var(x), Goal::Or(goals))
189            }
190            ast::Expr::Alter { brchs, span: _ } => {
191                let x = self.fresh_var();
192                let mut goals = Vec::new();
193                for body in brchs {
194                    let (term, goal) = self.translate_expr(body);
195                    let goal = Goal::And(vec![goal, Goal::Eq(Term::Var(x), term)]);
196                    goals.push(goal);
197                }
198                (Term::Var(x), Goal::Or(goals))
199            }
200            ast::Expr::Fresh {
201                vars: new_vars,
202                cont,
203                span: _,
204            } => {
205                let new_vars: Vec<Ident> = new_vars.iter().map(|var| var.ident).collect();
206                let vec: Vec<(Ident, TypeId)> = new_vars
207                    .iter()
208                    .map(|var| (*var, TypeId::Var(Ident::fresh(&"a"))))
209                    .collect();
210                self.vars.extend_from_slice(&vec[..]);
211                self.translate_expr(cont)
212            }
213            ast::Expr::Guard {
214                lhs,
215                rhs,
216                cont,
217                span: _,
218            } => {
219                let (term1, goal1) = self.translate_expr(lhs);
220                let (term2, goal2) =
221                    self.translate_expr(rhs.as_deref().unwrap_or(&Box::new(ast::Expr::Lit {
222                        lit: LitVal::Bool(true),
223                        span: logos::Span { start: 0, end: 0 },
224                    })));
225                let (term3, goal3) = self.translate_expr(cont);
226                (
227                    term3,
228                    Goal::And(vec![goal1, goal2, Goal::Eq(term1, term2), goal3]),
229                )
230            }
231            ast::Expr::Undefined { span: _ } => {
232                (Term::Var(Ident::dummy(&"@placeholder")), Goal::Lit(false))
233            }
234        }
235    }
236
237    fn translate_patn(&mut self, patn: &ast::Pattern) -> TermId {
238        match patn {
239            ast::Pattern::Lit { lit, span: _ } => TermId::Lit(*lit),
240            ast::Pattern::Var { var, span: _ } => {
241                self.vars.push((var.ident, TypeId::Var(Ident::fresh(&"a"))));
242                TermId::Var(var.ident)
243            }
244            ast::Pattern::Cons {
245                cons,
246                flds,
247                span: _,
248            } => {
249                let flds = flds.iter().map(|fld| self.translate_patn(fld)).collect();
250                TermId::Cons(OptCons::Some(cons.ident), flds)
251            }
252            ast::Pattern::Tuple { flds, span: _ } => {
253                let flds: Vec<TermId> = flds.iter().map(|fld| self.translate_patn(fld)).collect();
254                TermId::Cons(OptCons::None, flds)
255            }
256        }
257    }
258}
259
260fn translate_data_decl(data: &ast::DataDecl) -> DataDecl {
261    let name = data.name.ident;
262    let polys = data.polys.iter().map(|poly| poly.ident).collect();
263    let cons = data.cons.iter().map(translate_constructor).collect();
264    DataDecl { name, polys, cons }
265}
266
267fn translate_constructor(cons: &ast::Constructor) -> Constructor {
268    let name = cons.name.ident;
269    let flds = cons.flds.iter().map(translate_type).collect();
270    Constructor { name, flds }
271}
272
273fn translate_query(query: &ast::QueryDecl) -> QueryDecl {
274    QueryDecl {
275        entry: query.entry.ident,
276        params: query
277            .params
278            .iter()
279            .map(|(param, _span)| translate_query_param(param))
280            .collect(),
281    }
282}
283
284fn translate_query_param(param: &ast::QueryParam) -> QueryParam {
285    match param {
286        ast::QueryParam::DepthStep(x) => QueryParam::DepthStep(*x),
287        ast::QueryParam::DepthLimit(x) => QueryParam::DepthLimit(*x),
288        ast::QueryParam::AnswerLimit(x) => QueryParam::AnswerLimit(*x),
289        ast::QueryParam::AnswerPause(x) => QueryParam::AnswerPause(*x),
290    }
291}
292
293fn translate_type(typ: &ast::Type) -> TypeId {
294    match typ {
295        ast::Type::Lit { lit, span: _ } => Term::Lit(*lit),
296        ast::Type::Var { var, span: _ } => Term::Var(var.ident),
297        ast::Type::Cons {
298            cons,
299            flds,
300            span: _,
301        } => {
302            let flds = flds.iter().map(translate_type).collect();
303            Term::Cons(OptCons::Some(cons.ident), flds)
304        }
305        ast::Type::Tuple { flds, span: _ } => {
306            let flds: Vec<TypeId> = flds.iter().map(translate_type).collect();
307            Term::Cons(OptCons::None, flds)
308        }
309    }
310}
311
312pub fn logic_translation(prog: &ast::Program) -> Program {
313    let mut datas: HashMap<Ident, DataDecl> = HashMap::new();
314    for data in prog.datas.iter() {
315        let res = translate_data_decl(data);
316        datas.insert(data.name.ident, res);
317    }
318
319    let mut pass = Translater::new();
320    let mut preds = HashMap::new();
321    for func in prog.funcs.iter() {
322        let res = pass.translate_func(func);
323        preds.insert(func.name.ident, res);
324    }
325
326    let mut querys = Vec::new();
327    for query in prog.querys.iter() {
328        let res = translate_query(query);
329        querys.push(res);
330    }
331    Program {
332        datas,
333        preds,
334        querys,
335    }
336}
337
/// Smoke test: parses a small program and prints its logic translation.
#[test]
#[ignore = "just to see result"]
fn prog_to_pred_test() {
    // `append` builds and returns a list, so its declared result type is `IntList`
    // (it was previously mis-declared as `Int`).
    let src: &'static str = r#"
datatype IntList where
| Cons(Int, IntList)
| Nil
end

function append(xs: IntList, x: Int) -> IntList
begin
    match xs with
    | Cons(head, tail) =>
        Cons(head, append(tail, x))
    | Nil => Cons(x, Nil)
    end
end
"#;
    let (prog, errs) = crate::syntax::parser::parse_program(&src);
    assert!(errs.is_empty());

    let prog = logic_translation(&prog);
    println!("{:#?}", prog);
}