1use anyhow::{Result, anyhow};
16use fxhash::FxHashSet;
17use mangle_ir::physical::{self, Aggregate, CmpOp, Condition, DataSource, Expr, Op, Operand};
18use mangle_ir::{Inst, InstId, Ir, NameId};
19
20pub struct Planner<'a> {
21 ir: &'a mut Ir,
22 delta_pred: Option<NameId>,
23 fresh_counter: usize,
24}
25
26impl<'a> Planner<'a> {
27 pub fn new(ir: &'a mut Ir) -> Self {
28 Self {
29 ir,
30 delta_pred: None,
31 fresh_counter: 0,
32 }
33 }
34
35 pub fn with_delta(mut self, delta_pred: NameId) -> Self {
36 self.delta_pred = Some(delta_pred);
37 self
38 }
39
40 pub fn plan_rule(mut self, rule_id: InstId) -> Result<Op> {
41 let (head, premises, transform) = match self.ir.get(rule_id) {
42 Inst::Rule {
43 head,
44 premises,
45 transform,
46 } => (*head, premises.clone(), transform.clone()),
47 _ => return Err(anyhow!("Not a rule")),
48 };
49
50 let blocks = self.split_transforms(transform);
52 let num_blocks = blocks.len();
53
54 let mut ops = Vec::new();
55 let mut current_source: Option<(NameId, Vec<NameId>)> = None;
56 let mut bound_vars = FxHashSet::default();
57
58 for (i, block) in blocks.into_iter().enumerate() {
59 let is_last = i == num_blocks - 1;
60
61 if i == 0 {
62 if is_last {
64 let op = self.plan_join_sequence(
66 premises.clone(),
67 &mut bound_vars,
68 |planner, vars| {
69 planner.plan_transforms_sequence(&block, vars, |p, v| {
70 p.plan_head_insert(head, v)
71 })
72 },
73 )?;
74 ops.push(op);
75 } else {
76 let temp_rel = self.fresh_var("temp_grp");
78 let mut capture_vars: Vec<NameId> = Vec::new(); let op = self.plan_join_sequence(
81 premises.clone(),
82 &mut bound_vars,
83 |planner, vars| {
84 planner.plan_transforms_sequence(&block, vars, |_, v| {
85 let mut sorted_vars: Vec<NameId> = v.iter().cloned().collect();
86 sorted_vars.sort();
87 capture_vars = sorted_vars.clone();
88 let args =
89 sorted_vars.iter().map(|&var| Operand::Var(var)).collect();
90 Ok(Op::Insert {
91 relation: temp_rel,
92 args,
93 })
94 })
95 },
96 )?;
97 ops.push(op);
98 current_source = Some((temp_rel, capture_vars));
99 }
100 } else {
101 let (src_rel, src_vars) = current_source.take().expect("No source for aggregation");
103
104 if is_last {
105 let op = self.plan_block_k(src_rel, src_vars, &block, |p, v| {
106 p.plan_head_insert(head, v)
107 })?;
108 ops.push(op);
109 } else {
110 let next_temp = self.fresh_var("temp_grp");
111 let mut next_vars: Vec<NameId> = Vec::new();
112
113 let op = self.plan_block_k(src_rel, src_vars, &block, |_, v| {
114 let mut sorted_vars: Vec<NameId> = v.iter().cloned().collect();
115 sorted_vars.sort();
116 next_vars = sorted_vars.clone();
117 let args = sorted_vars.iter().map(|&var| Operand::Var(var)).collect();
118 Ok(Op::Insert {
119 relation: next_temp,
120 args,
121 })
122 })?;
123 ops.push(op);
124 current_source = Some((next_temp, next_vars));
125 }
126 }
127 }
128
129 if ops.len() == 1 {
130 Ok(ops.remove(0))
131 } else {
132 Ok(Op::Seq(ops))
133 }
134 }
135
136 fn split_transforms(&self, transforms: Vec<InstId>) -> Vec<Vec<InstId>> {
137 let mut blocks = Vec::new();
138 let mut current = Vec::new();
139 for t in transforms {
140 let inst = self.ir.get(t);
141 if let Inst::Transform { var: None, .. } = inst {
142 blocks.push(current);
143 current = Vec::new();
144 }
145 current.push(t);
146 }
147 blocks.push(current);
148 blocks
149 }
150
151 fn plan_block_k<F>(
152 &mut self,
153 source_rel: NameId,
154 source_vars: Vec<NameId>,
155 block: &[InstId],
156 continuation: F,
157 ) -> Result<Op>
158 where
159 F: FnOnce(&mut Self, &mut FxHashSet<NameId>) -> Result<Op>,
160 {
161 let do_stmt = block[0];
162 let rest = &block[1..];
163
164 let keys_insts = self.get_transform_app_args(do_stmt)?;
165 let mut keys = Vec::new();
166 for k in keys_insts {
167 if let Inst::Var(v) = self.ir.get(k) {
168 keys.push(*v);
169 } else {
170 return Err(anyhow!("GroupBy keys must be variables"));
171 }
172 }
173
174 let mut aggregates = Vec::new();
175 let mut lets = Vec::new();
176 for &t in rest {
177 if let Some(agg) = self.try_parse_aggregate(t)? {
178 aggregates.push(agg);
179 } else {
180 lets.push(t);
181 }
182 }
183
184 let mut inner_vars = FxHashSet::default();
185 for &k in &keys {
186 inner_vars.insert(k);
187 }
188 for agg in &aggregates {
189 inner_vars.insert(agg.var);
190 }
191
192 let body = self.plan_transforms_sequence(&lets, &mut inner_vars, continuation)?;
193
194 Ok(Op::GroupBy {
195 source: source_rel,
196 vars: source_vars,
197 keys,
198 aggregates,
199 body: Box::new(body),
200 })
201 }
202
203 fn plan_transforms_sequence<F>(
204 &mut self,
205 transforms: &[InstId],
206 bound_vars: &mut FxHashSet<NameId>,
207 continuation: F,
208 ) -> Result<Op>
209 where
210 F: FnOnce(&mut Self, &mut FxHashSet<NameId>) -> Result<Op>,
211 {
212 if transforms.is_empty() {
213 return continuation(self, bound_vars);
214 }
215
216 let t_id = transforms[0];
217 let rest = &transforms[1..];
218
219 let inst = self.ir.get(t_id).clone();
220 if let Inst::Transform {
221 var: Some(var),
222 app,
223 } = inst
224 {
225 self.inst_to_expr(app, |planner, expr| {
226 bound_vars.insert(var);
227 let body = planner.plan_transforms_sequence(rest, bound_vars, continuation)?;
228 Ok(Op::Let {
229 var,
230 expr,
231 body: Box::new(body),
232 })
233 })
234 } else {
235 Err(anyhow!("Unexpected transform in sequence"))
237 }
238 }
239
240 fn fresh_var(&mut self, prefix: &str) -> NameId {
241 let id = self.fresh_counter;
242 self.fresh_counter += 1;
243 let name = format!("${}_{}", prefix, id);
244 self.ir.intern_name(name)
245 }
246
247 fn plan_join_sequence<F>(
248 &mut self,
249 mut premises: Vec<InstId>,
250 bound_vars: &mut FxHashSet<NameId>,
251 continuation: F,
252 ) -> Result<Op>
253 where
254 F: FnOnce(&mut Self, &mut FxHashSet<NameId>) -> Result<Op>,
255 {
256 if premises.is_empty() {
257 return continuation(self, bound_vars);
258 }
259
260 let current_premise = premises.remove(0);
261 let inst = self.ir.get(current_premise).clone();
262
263 match inst {
264 Inst::Atom { predicate, args } => {
265 let mut scan_vars = Vec::new();
266 let mut new_bindings = Vec::new();
267
268 let mut index_lookup: Option<(usize, Operand)> = None;
270
271 for (i, arg) in args.iter().enumerate() {
272 let arg_inst = self.ir.get(*arg).clone();
273 match arg_inst {
274 Inst::Var(v) if bound_vars.contains(&v) => {
275 if index_lookup.is_none() {
276 index_lookup = Some((i, Operand::Var(v)));
277 }
278 }
279 Inst::Number(n) => {
280 if index_lookup.is_none() {
281 index_lookup =
282 Some((i, Operand::Const(physical::Constant::Number(n))));
283 }
284 }
285 Inst::String(s) => {
286 if index_lookup.is_none() {
287 index_lookup =
288 Some((i, Operand::Const(physical::Constant::String(s))));
289 }
290 }
291 Inst::Name(n) => {
292 if index_lookup.is_none() {
293 index_lookup =
294 Some((i, Operand::Const(physical::Constant::Name(n))));
295 }
296 }
297 _ => {}
298 }
299 }
300
301 for arg in &args {
302 if let Inst::Var(v) = self.ir.get(*arg)
303 && !bound_vars.contains(v)
304 {
305 scan_vars.push(*v);
306 new_bindings.push(*v);
307 continue;
308 }
309 let tmp = self.fresh_var("scan");
310 scan_vars.push(tmp);
311 new_bindings.push(tmp);
312 }
313
314 for v in &new_bindings {
315 bound_vars.insert(*v);
316 }
317
318 let body = self.plan_join_sequence(premises, bound_vars, continuation)?;
319 let wrapped_body = self.apply_constraints(&args, &scan_vars, body)?;
320
321 let source = if let Some((col_idx, key)) = index_lookup {
322 DataSource::IndexLookup {
323 relation: predicate,
324 col_idx,
325 key,
326 vars: scan_vars,
327 }
328 } else if Some(predicate) == self.delta_pred {
329 DataSource::ScanDelta {
330 relation: predicate,
331 vars: scan_vars,
332 }
333 } else {
334 DataSource::Scan {
335 relation: predicate,
336 vars: scan_vars,
337 }
338 };
339 Ok(Op::Iterate {
340 source,
341 body: Box::new(wrapped_body),
342 })
343 }
344 Inst::Eq(l, r) => {
345 let body = self.plan_join_sequence(premises, bound_vars, continuation)?;
346 self.wrap_eq_check(l, r, body)
347 }
348 Inst::Ineq(l, r) => {
349 let body = self.plan_join_sequence(premises, bound_vars, continuation)?;
350 self.with_eval(l, |this, left_op| {
351 this.with_eval(r, |_this, right_op| {
352 Ok(Op::Filter {
353 cond: Condition::Cmp {
354 op: CmpOp::Neq,
355 left: left_op.clone(),
356 right: right_op,
357 },
358 body: Box::new(body),
359 })
360 })
361 })
362 }
363 Inst::NegAtom(inner) => {
364 let inner_inst = self.ir.get(inner).clone();
365 if let Inst::Atom { predicate, args } = inner_inst {
366 let body = self.plan_join_sequence(premises, bound_vars, continuation)?;
367 let mut neg_args = Vec::new();
368 for arg in &args {
369 let arg_inst = self.ir.get(*arg).clone();
370 match arg_inst {
371 Inst::Var(v) => neg_args.push(Operand::Var(v)),
372 Inst::Number(n) => {
373 neg_args.push(Operand::Const(physical::Constant::Number(n)))
374 }
375 Inst::String(s) => {
376 neg_args.push(Operand::Const(physical::Constant::String(s)))
377 }
378 Inst::Name(n) => {
379 neg_args.push(Operand::Const(physical::Constant::Name(n)))
380 }
381 _ => return Err(anyhow!("Complex expression in negated atom")),
382 }
383 }
384 Ok(Op::Filter {
385 cond: Condition::Negation {
386 relation: predicate,
387 args: neg_args,
388 },
389 body: Box::new(body),
390 })
391 } else {
392 Err(anyhow!("NegAtom wraps non-atom"))
393 }
394 }
395 _ => Err(anyhow!("Unsupported premise type: {:?}", inst)),
396 }
397 }
398
399 fn apply_constraints(
400 &mut self,
401 args: &[InstId],
402 scan_vars: &[NameId],
403 mut body: Op,
404 ) -> Result<Op> {
405 for (i, arg) in args.iter().enumerate().rev() {
406 let scan_var = scan_vars[i];
407 let arg_inst = self.ir.get(*arg).clone();
408 match arg_inst {
409 Inst::Var(v) => {
410 if v == scan_var {
411 continue;
412 }
413 body = Op::Filter {
414 cond: Condition::Cmp {
415 op: CmpOp::Eq,
416 left: Operand::Var(scan_var),
417 right: Operand::Var(v),
418 },
419 body: Box::new(body),
420 };
421 }
422 _ => {
423 body = self.wrap_eval_check(*arg, Operand::Var(scan_var), body)?;
424 }
425 }
426 }
427 Ok(body)
428 }
429
430 fn wrap_eq_check(&mut self, l: InstId, r: InstId, body: Op) -> Result<Op> {
431 self.with_eval(l, |this, op_l| {
432 this.with_eval(r, |_this, op_r| {
433 Ok(Op::Filter {
434 cond: Condition::Cmp {
435 op: CmpOp::Eq,
436 left: op_l,
437 right: op_r,
438 },
439 body: Box::new(body),
440 })
441 })
442 })
443 }
444
445 fn wrap_eval_check(&mut self, inst: InstId, target: Operand, body: Op) -> Result<Op> {
446 self.with_eval(inst, |_this, op| {
447 Ok(Op::Filter {
448 cond: Condition::Cmp {
449 op: CmpOp::Eq,
450 left: target,
451 right: op,
452 },
453 body: Box::new(body),
454 })
455 })
456 }
457
458 fn with_eval<F>(&mut self, inst: InstId, f: F) -> Result<Op>
459 where
460 F: FnOnce(&mut Self, Operand) -> Result<Op>,
461 {
462 let i = self.ir.get(inst).clone();
463 match i {
464 Inst::Var(v) => f(self, Operand::Var(v)),
465 Inst::String(s) => f(self, Operand::Const(physical::Constant::String(s))),
466 Inst::Number(n) => f(self, Operand::Const(physical::Constant::Number(n))),
467 Inst::Name(n) => f(self, Operand::Const(physical::Constant::Name(n))),
468 Inst::ApplyFn { function, args } => self.with_eval_args(
469 &args,
470 0,
471 Vec::new(),
472 Box::new(|this, ops| {
473 let tmp = this.fresh_var("call");
474 let inner = f(this, Operand::Var(tmp))?;
475 Ok(Op::Let {
476 var: tmp,
477 expr: Expr::Call {
478 function,
479 args: ops,
480 },
481 body: Box::new(inner),
482 })
483 }),
484 ),
485 _ => Err(anyhow!("Unsupported expression in evaluation")),
486 }
487 }
488
489 fn inst_to_expr<F>(&mut self, inst: InstId, f: F) -> Result<Op>
490 where
491 F: FnOnce(&mut Self, Expr) -> Result<Op>,
492 {
493 let i = self.ir.get(inst).clone();
494 match i {
495 Inst::ApplyFn { function, args } => self.with_eval_args(
496 &args,
497 0,
498 Vec::new(),
499 Box::new(|this, ops| {
500 f(
501 this,
502 Expr::Call {
503 function,
504 args: ops,
505 },
506 )
507 }),
508 ),
509 _ => self.with_eval(inst, |this, op| f(this, Expr::Value(op))),
510 }
511 }
512
513 fn with_eval_args(
514 &mut self,
515 args: &[InstId],
516 index: usize,
517 mut acc: Vec<Operand>,
518 f: Box<dyn FnOnce(&mut Self, Vec<Operand>) -> Result<Op> + '_>,
519 ) -> Result<Op> {
520 if index >= args.len() {
521 return f(self, acc);
522 }
523 self.with_eval(args[index], |this, op| {
524 acc.push(op);
525 this.with_eval_args(args, index + 1, acc, f)
526 })
527 }
528
529 fn plan_head_insert(
530 &mut self,
531 head: InstId,
532 _bound_vars: &mut FxHashSet<NameId>,
533 ) -> Result<Op> {
534 let inst = self.ir.get(head).clone();
535 if let Inst::Atom { predicate, args } = inst {
536 self.with_eval_args(
537 &args,
538 0,
539 Vec::new(),
540 Box::new(|_this, ops| {
541 Ok(Op::Insert {
542 relation: predicate,
543 args: ops,
544 })
545 }),
546 )
547 } else {
548 Err(anyhow!("Head must be an atom"))
549 }
550 }
551
552 fn get_transform_app_args(&self, t_id: InstId) -> Result<Vec<InstId>> {
553 if let Inst::Transform { app, .. } = self.ir.get(t_id)
554 && let Inst::ApplyFn { args, .. } = self.ir.get(*app)
555 {
556 return Ok(args.clone());
557 }
558 Err(anyhow!("Invalid transform structure"))
559 }
560
561 fn try_parse_aggregate(&mut self, t_id: InstId) -> Result<Option<Aggregate>> {
562 let inst = self.ir.get(t_id).clone();
563 if let Inst::Transform {
564 var: Some(var),
565 app,
566 } = inst
567 && let Inst::ApplyFn { function, args } = self.ir.get(app).clone()
568 {
569 let func_name = self.ir.resolve_name(function);
570 if matches!(
571 func_name,
572 "fn:sum" | "fn:count" | "fn:max" | "fn:min" | "fn:collect"
573 ) {
574 let mut op_args = Vec::new();
575 for arg in args {
576 let arg_inst = self.ir.get(arg).clone();
577 match arg_inst {
578 Inst::Var(v) => op_args.push(Operand::Var(v)),
579 Inst::Number(n) => {
580 op_args.push(Operand::Const(physical::Constant::Number(n)))
581 }
582 _ => {
583 return Err(anyhow!(
584 "Complex expressions in aggregates not supported yet"
585 ));
586 }
587 }
588 }
589 return Ok(Some(Aggregate {
590 var,
591 func: function,
592 args: op_args,
593 }));
594 }
595 }
596 Ok(None)
597 }
598}