jaq_interpret/
mir.rs

//! Mid-level Intermediate Representation (MIR) of definitions and filters.
//!
//! This mainly analyses where recursive calls occur, which is needed
//! to execute tail-recursive filters efficiently.
5
6use crate::hir::{self, ArgIdx, NativeId, Num, RelId, VarIdx};
7use alloc::{boxed::Box, vec::Vec};
8use jaq_syn::filter::{BinaryOp, Filter as Expr, Fold};
9use jaq_syn::Spanned;
10
11pub type Filter = jaq_syn::filter::Filter<Call, VarIdx, Num>;
12
13pub struct Main {
14    pub defs: Vec<Def>,
15    pub body: Spanned<Filter>,
16}
17
18pub struct Def {
19    pub lhs: jaq_syn::Call,
20    pub rhs: Main,
21    /// is the filter tail-recursive?
22    pub tailrec: bool,
23}
24
25#[derive(Debug, Clone)]
26pub enum Call {
27    Def { id: RelId, skip: usize, tail: bool },
28    Arg(ArgIdx),
29    Native(NativeId),
30}
31
32#[derive(Debug, PartialEq, Eq)]
33pub enum Relative {
34    Parent { tailrec: bool },
35    Sibling { tailrec: Tailrec },
36}
37
38#[derive(Default)]
39pub struct Ctx {
40    /// accessible defined filters
41    callable: Vec<Relative>,
42}
43
44/// which filters can be called tail-recursively at the current point
45pub type Tailrec = alloc::collections::BTreeSet<RelId>;
46
47impl Ctx {
48    pub fn main(&mut self, main: hir::Main, tr: Tailrec) -> Main {
49        for _ in &main.defs {
50            self.callable.push(Relative::Sibling {
51                tailrec: tr.clone(),
52            });
53        }
54        //std::dbg!("handle body", &main.body, &self.callable);
55        let body = self.expr(main.body, tr);
56        //std::dbg!("defs: ", &main.defs);
57        let defs = main.defs.into_iter().rev().map(|def| {
58            //std::dbg!("handle def", &def);
59            let tailrec = match self.callable.pop().unwrap() {
60                Relative::Sibling { tailrec } => tailrec,
61                _ => panic!(),
62            };
63            self.def(def, tailrec)
64        });
65        let mut defs: Vec<_> = defs.collect();
66        defs.reverse();
67        Main { defs, body }
68    }
69
70    pub fn def(&mut self, def: hir::Def, mut tr: Tailrec) -> Def {
71        //std::dbg!("treating def:", &def.lhs);
72        tr.insert(RelId(self.callable.len()));
73        self.callable.push(Relative::Parent { tailrec: false });
74
75        Def {
76            lhs: def.lhs,
77            rhs: self.main(def.rhs, tr),
78            tailrec: match self.callable.pop().unwrap() {
79                Relative::Parent { tailrec } => tailrec,
80                _ => panic!(),
81            },
82        }
83    }
84
85    fn expr(&mut self, f: Spanned<hir::Filter>, tr: Tailrec) -> Spanned<Filter> {
86        // no tail-recursion
87        let notr = Tailrec::default;
88        let get = |ctx: &mut Self, f, tr| Box::new(ctx.expr(f, tr));
89        let result = match f.0 {
90            Expr::Call(call, args) => {
91                let args: Vec<_> = args.into_iter().map(|arg| self.expr(arg, notr())).collect();
92                //std::dbg!(&call);
93                //std::dbg!(&self.callable);
94                let call = match call {
95                    hir::Call::Arg(a) => Call::Arg(a),
96                    hir::Call::Native(n) => Call::Native(n),
97                    hir::Call::Def { id, skip } => {
98                        let tail = match &mut self.callable[id.0] {
99                            Relative::Parent { ref mut tailrec } => {
100                                let tail = tr.contains(&id);
101                                *tailrec = *tailrec || tail;
102                                tail
103                            }
104                            Relative::Sibling { ref mut tailrec } => {
105                                *tailrec = tailrec.intersection(&tr).cloned().collect();
106                                false
107                            }
108                        };
109                        Call::Def { id, skip, tail }
110                    }
111                };
112                Expr::Call(call, args)
113            }
114            Expr::Var(v) => Expr::Var(v),
115            Expr::Binary(l, BinaryOp::Comma, r) => {
116                let l = get(self, *l, tr.clone());
117                let r = get(self, *r, tr);
118                Expr::Binary(l, BinaryOp::Comma, r)
119            }
120            Expr::Binary(l, op @ (BinaryOp::Alt | BinaryOp::Pipe(_)), r) => {
121                let l = get(self, *l, notr());
122                let r = get(self, *r, tr);
123                Expr::Binary(l, op, r)
124            }
125            Expr::Binary(l, op, r) => {
126                Expr::Binary(get(self, *l, notr()), op, get(self, *r, notr()))
127            }
128
129            Expr::Fold(typ, Fold { xs, x, init, f }) => {
130                let xs = get(self, *xs, notr());
131                let init = get(self, *init, notr());
132                let f = get(self, *f, notr());
133                Expr::Fold(typ, Fold { xs, x, init, f })
134            }
135            Expr::Id => Expr::Id,
136            Expr::Recurse => Expr::Recurse,
137            Expr::Num(n) => Expr::Num(n),
138            Expr::Str(s) => Expr::Str(Box::new((*s).map(|f| self.expr(f, notr())))),
139            Expr::Array(a) => Expr::Array(a.map(|a| get(self, *a, notr()))),
140            Expr::Object(o) => Expr::Object(
141                o.into_iter()
142                    .map(|kv| kv.map(|f| self.expr(f, notr())))
143                    .collect(),
144            ),
145            Expr::Try(f) => Expr::Try(get(self, *f, notr())),
146            Expr::Neg(f) => Expr::Neg(get(self, *f, notr())),
147
148            Expr::Ite(if_thens, else_) => {
149                let if_thens = if_thens
150                    .into_iter()
151                    .map(|(i, t)| (self.expr(i, notr()), self.expr(t, tr.clone())));
152                Expr::Ite(if_thens.collect(), else_.map(|else_| get(self, *else_, tr)))
153            }
154            Expr::TryCatch(try_, catch_) => {
155                Expr::TryCatch(get(self, *try_, notr()), catch_.map(|c| get(self, *c, tr)))
156            }
157            Expr::Path(f, path) => {
158                let f = get(self, *f, notr());
159                let path = path
160                    .into_iter()
161                    .map(|(p, opt)| (p.map(|p| self.expr(p, notr())), opt));
162                Expr::Path(f, path.collect())
163            }
164        };
165        (result, f.1)
166    }
167}