1use crate::hir::{self, ArgIdx, NativeId, Num, RelId, VarIdx};
7use alloc::{boxed::Box, vec::Vec};
8use jaq_syn::filter::{BinaryOp, Filter as Expr, Fold};
9use jaq_syn::Spanned;
10
11pub type Filter = jaq_syn::filter::Filter<Call, VarIdx, Num>;
12
/// Result of lowering an `hir::Main` with [`Ctx::main`]: a body together
/// with the definitions local to it.
pub struct Main {
    /// Local definitions, in the same order as in the HIR input.
    pub defs: Vec<Def>,
    /// The lowered body, with every call to a definition classified.
    pub body: Spanned<Filter>,
}
17
/// A lowered definition, as produced by [`Ctx::def`].
pub struct Def {
    /// Left-hand side of the definition, copied unchanged from the HIR.
    pub lhs: jaq_syn::Call,
    /// Right-hand side: the definition's body together with its own
    /// local definitions.
    pub rhs: Main,
    /// `true` if at least one call to this definition made from inside
    /// `rhs` was classified as a tail call (see `Call::Def::tail`).
    pub tailrec: bool,
}
24
/// Target of a call in a lowered [`Filter`].
#[derive(Debug, Clone)]
pub enum Call {
    /// Call to a definition. `id.0` indexes the [`Ctx`] stack of callable
    /// definitions; `skip` is carried over unchanged from the HIR call;
    /// `tail` is `true` if the callee is an enclosing definition and the
    /// call sits in tail position relative to it (a tail-recursive call).
    Def { id: RelId, skip: usize, tail: bool },
    /// Carried over unchanged from `hir::Call::Arg`.
    Arg(ArgIdx),
    /// Carried over unchanged from `hir::Call::Native`.
    Native(NativeId),
}
31
/// Entry on the [`Ctx`] stack of callable definitions, describing how that
/// definition relates to the expression currently being lowered.
#[derive(Debug, PartialEq, Eq)]
pub enum Relative {
    /// An enclosing definition whose right-hand side is being lowered;
    /// `tailrec` becomes `true` once a tail call to it has been found.
    Parent { tailrec: bool },
    /// A definition whose right-hand side has not been lowered yet.
    /// `tailrec` starts as the set in effect for the enclosing `Main` and is
    /// intersected with the set in effect at every call to this definition.
    Sibling { tailrec: Tailrec },
}
37
/// State for lowering HIR into this representation while classifying calls
/// to definitions as tail calls or not.
#[derive(Default)]
pub struct Ctx {
    /// Stack of definitions that are currently callable; `RelId(i)` refers
    /// to entry `i`.
    callable: Vec<Relative>,
}
43
44pub type Tailrec = alloc::collections::BTreeSet<RelId>;
46
47impl Ctx {
48 pub fn main(&mut self, main: hir::Main, tr: Tailrec) -> Main {
49 for _ in &main.defs {
50 self.callable.push(Relative::Sibling {
51 tailrec: tr.clone(),
52 });
53 }
54 let body = self.expr(main.body, tr);
56 let defs = main.defs.into_iter().rev().map(|def| {
58 let tailrec = match self.callable.pop().unwrap() {
60 Relative::Sibling { tailrec } => tailrec,
61 _ => panic!(),
62 };
63 self.def(def, tailrec)
64 });
65 let mut defs: Vec<_> = defs.collect();
66 defs.reverse();
67 Main { defs, body }
68 }
69
70 pub fn def(&mut self, def: hir::Def, mut tr: Tailrec) -> Def {
71 tr.insert(RelId(self.callable.len()));
73 self.callable.push(Relative::Parent { tailrec: false });
74
75 Def {
76 lhs: def.lhs,
77 rhs: self.main(def.rhs, tr),
78 tailrec: match self.callable.pop().unwrap() {
79 Relative::Parent { tailrec } => tailrec,
80 _ => panic!(),
81 },
82 }
83 }
84
85 fn expr(&mut self, f: Spanned<hir::Filter>, tr: Tailrec) -> Spanned<Filter> {
86 let notr = Tailrec::default;
88 let get = |ctx: &mut Self, f, tr| Box::new(ctx.expr(f, tr));
89 let result = match f.0 {
90 Expr::Call(call, args) => {
91 let args: Vec<_> = args.into_iter().map(|arg| self.expr(arg, notr())).collect();
92 let call = match call {
95 hir::Call::Arg(a) => Call::Arg(a),
96 hir::Call::Native(n) => Call::Native(n),
97 hir::Call::Def { id, skip } => {
98 let tail = match &mut self.callable[id.0] {
99 Relative::Parent { ref mut tailrec } => {
100 let tail = tr.contains(&id);
101 *tailrec = *tailrec || tail;
102 tail
103 }
104 Relative::Sibling { ref mut tailrec } => {
105 *tailrec = tailrec.intersection(&tr).cloned().collect();
106 false
107 }
108 };
109 Call::Def { id, skip, tail }
110 }
111 };
112 Expr::Call(call, args)
113 }
114 Expr::Var(v) => Expr::Var(v),
115 Expr::Binary(l, BinaryOp::Comma, r) => {
116 let l = get(self, *l, tr.clone());
117 let r = get(self, *r, tr);
118 Expr::Binary(l, BinaryOp::Comma, r)
119 }
120 Expr::Binary(l, op @ (BinaryOp::Alt | BinaryOp::Pipe(_)), r) => {
121 let l = get(self, *l, notr());
122 let r = get(self, *r, tr);
123 Expr::Binary(l, op, r)
124 }
125 Expr::Binary(l, op, r) => {
126 Expr::Binary(get(self, *l, notr()), op, get(self, *r, notr()))
127 }
128
129 Expr::Fold(typ, Fold { xs, x, init, f }) => {
130 let xs = get(self, *xs, notr());
131 let init = get(self, *init, notr());
132 let f = get(self, *f, notr());
133 Expr::Fold(typ, Fold { xs, x, init, f })
134 }
135 Expr::Id => Expr::Id,
136 Expr::Recurse => Expr::Recurse,
137 Expr::Num(n) => Expr::Num(n),
138 Expr::Str(s) => Expr::Str(Box::new((*s).map(|f| self.expr(f, notr())))),
139 Expr::Array(a) => Expr::Array(a.map(|a| get(self, *a, notr()))),
140 Expr::Object(o) => Expr::Object(
141 o.into_iter()
142 .map(|kv| kv.map(|f| self.expr(f, notr())))
143 .collect(),
144 ),
145 Expr::Try(f) => Expr::Try(get(self, *f, notr())),
146 Expr::Neg(f) => Expr::Neg(get(self, *f, notr())),
147
148 Expr::Ite(if_thens, else_) => {
149 let if_thens = if_thens
150 .into_iter()
151 .map(|(i, t)| (self.expr(i, notr()), self.expr(t, tr.clone())));
152 Expr::Ite(if_thens.collect(), else_.map(|else_| get(self, *else_, tr)))
153 }
154 Expr::TryCatch(try_, catch_) => {
155 Expr::TryCatch(get(self, *try_, notr()), catch_.map(|c| get(self, *c, tr)))
156 }
157 Expr::Path(f, path) => {
158 let f = get(self, *f, notr());
159 let path = path
160 .into_iter()
161 .map(|(p, opt)| (p.map(|p| self.expr(p, notr())), opt));
162 Expr::Path(f, path.collect())
163 }
164 };
165 (result, f.1)
166 }
167}