1use anyhow::{Result, anyhow};
16use fxhash::FxHashSet;
17use mangle_ir::physical::{self, Aggregate, CmpOp, Condition, DataSource, Expr, Op, Operand};
18use mangle_ir::{Inst, InstId, Ir, NameId};
19
20pub struct Planner<'a> {
21 ir: &'a mut Ir,
22 delta_pred: Option<NameId>,
23 fresh_counter: usize,
24}
25
26impl<'a> Planner<'a> {
27 pub fn new(ir: &'a mut Ir) -> Self {
28 Self {
29 ir,
30 delta_pred: None,
31 fresh_counter: 0,
32 }
33 }
34
35 pub fn with_delta(mut self, delta_pred: NameId) -> Self {
36 self.delta_pred = Some(delta_pred);
37 self
38 }
39
40 pub fn plan_rule(mut self, rule_id: InstId) -> Result<Op> {
41 let (head, premises, transform) = match self.ir.get(rule_id) {
42 Inst::Rule {
43 head,
44 premises,
45 transform,
46 } => (*head, premises.clone(), transform.clone()),
47 _ => return Err(anyhow!("Not a rule")),
48 };
49
50 let blocks = self.split_transforms(transform);
52 let num_blocks = blocks.len();
53
54 let mut ops = Vec::new();
55 let mut current_source: Option<(NameId, Vec<NameId>)> = None;
56 let mut bound_vars = FxHashSet::default();
57
58 for (i, block) in blocks.into_iter().enumerate() {
59 let is_last = i == num_blocks - 1;
60
61 if i == 0 {
62 if is_last {
64 let op = self.plan_join_sequence(
66 premises.clone(),
67 &mut bound_vars,
68 |planner, vars| {
69 planner.plan_transforms_sequence(&block, vars, |p, v| {
70 p.plan_head_insert(head, v)
71 })
72 },
73 )?;
74 ops.push(op);
75 } else {
76 let temp_rel = self.fresh_var("temp_grp");
78 let mut capture_vars: Vec<NameId> = Vec::new(); let op = self.plan_join_sequence(
81 premises.clone(),
82 &mut bound_vars,
83 |planner, vars| {
84 planner.plan_transforms_sequence(&block, vars, |_, v| {
85 let mut sorted_vars: Vec<NameId> = v.iter().cloned().collect();
86 sorted_vars.sort();
87 capture_vars = sorted_vars.clone();
88 let args =
89 sorted_vars.iter().map(|&var| Operand::Var(var)).collect();
90 Ok(Op::Insert {
91 relation: temp_rel,
92 args,
93 })
94 })
95 },
96 )?;
97 ops.push(op);
98 current_source = Some((temp_rel, capture_vars));
99 }
100 } else {
101 let (src_rel, src_vars) = current_source.take().expect("No source for aggregation");
103
104 if is_last {
105 let op = self.plan_block_k(src_rel, src_vars, &block, |p, v| {
106 p.plan_head_insert(head, v)
107 })?;
108 ops.push(op);
109 } else {
110 let next_temp = self.fresh_var("temp_grp");
111 let mut next_vars: Vec<NameId> = Vec::new();
112
113 let op = self.plan_block_k(src_rel, src_vars, &block, |_, v| {
114 let mut sorted_vars: Vec<NameId> = v.iter().cloned().collect();
115 sorted_vars.sort();
116 next_vars = sorted_vars.clone();
117 let args = sorted_vars.iter().map(|&var| Operand::Var(var)).collect();
118 Ok(Op::Insert {
119 relation: next_temp,
120 args,
121 })
122 })?;
123 ops.push(op);
124 current_source = Some((next_temp, next_vars));
125 }
126 }
127 }
128
129 if ops.len() == 1 {
130 Ok(ops.remove(0))
131 } else {
132 Ok(Op::Seq(ops))
133 }
134 }
135
136 fn split_transforms(&self, transforms: Vec<InstId>) -> Vec<Vec<InstId>> {
137 let mut blocks = Vec::new();
138 let mut current = Vec::new();
139 for t in transforms {
140 let inst = self.ir.get(t);
141 if let Inst::Transform { var: None, .. } = inst {
142 blocks.push(current);
143 current = Vec::new();
144 }
145 current.push(t);
146 }
147 blocks.push(current);
148 blocks
149 }
150
151 fn plan_block_k<F>(
152 &mut self,
153 source_rel: NameId,
154 source_vars: Vec<NameId>,
155 block: &[InstId],
156 continuation: F,
157 ) -> Result<Op>
158 where
159 F: FnOnce(&mut Self, &mut FxHashSet<NameId>) -> Result<Op>,
160 {
161 let do_stmt = block[0];
162 let rest = &block[1..];
163
164 let keys_insts = self.get_transform_app_args(do_stmt)?;
165 let mut keys = Vec::new();
166 for k in keys_insts {
167 if let Inst::Var(v) = self.ir.get(k) {
168 keys.push(*v);
169 } else {
170 return Err(anyhow!("GroupBy keys must be variables"));
171 }
172 }
173
174 let mut aggregates = Vec::new();
175 let mut lets = Vec::new();
176 for &t in rest {
177 if let Some(agg) = self.try_parse_aggregate(t)? {
178 aggregates.push(agg);
179 } else {
180 lets.push(t);
181 }
182 }
183
184 let mut inner_vars = FxHashSet::default();
185 for &k in &keys {
186 inner_vars.insert(k);
187 }
188 for agg in &aggregates {
189 inner_vars.insert(agg.var);
190 }
191
192 let body = self.plan_transforms_sequence(&lets, &mut inner_vars, continuation)?;
193
194 Ok(Op::GroupBy {
195 source: source_rel,
196 vars: source_vars,
197 keys,
198 aggregates,
199 body: Box::new(body),
200 })
201 }
202
203 fn plan_transforms_sequence<F>(
204 &mut self,
205 transforms: &[InstId],
206 bound_vars: &mut FxHashSet<NameId>,
207 continuation: F,
208 ) -> Result<Op>
209 where
210 F: FnOnce(&mut Self, &mut FxHashSet<NameId>) -> Result<Op>,
211 {
212 if transforms.is_empty() {
213 return continuation(self, bound_vars);
214 }
215
216 let t_id = transforms[0];
217 let rest = &transforms[1..];
218
219 let inst = self.ir.get(t_id).clone();
220 if let Inst::Transform {
221 var: Some(var),
222 app,
223 } = inst
224 {
225 self.inst_to_expr(app, |planner, expr| {
226 bound_vars.insert(var);
227 let body = planner.plan_transforms_sequence(rest, bound_vars, continuation)?;
228 Ok(Op::Let {
229 var,
230 expr,
231 body: Box::new(body),
232 })
233 })
234 } else {
235 Err(anyhow!("Unexpected transform in sequence"))
237 }
238 }
239
240 fn fresh_var(&mut self, prefix: &str) -> NameId {
241 let id = self.fresh_counter;
242 self.fresh_counter += 1;
243 let name = format!("${}_{}", prefix, id);
244 self.ir.intern_name(name)
245 }
246
247 fn plan_join_sequence<F>(
248 &mut self,
249 mut premises: Vec<InstId>,
250 bound_vars: &mut FxHashSet<NameId>,
251 continuation: F,
252 ) -> Result<Op>
253 where
254 F: FnOnce(&mut Self, &mut FxHashSet<NameId>) -> Result<Op>,
255 {
256 if premises.is_empty() {
257 return continuation(self, bound_vars);
258 }
259
260 let current_premise = premises.remove(0);
261 let inst = self.ir.get(current_premise).clone();
262
263 match inst {
264 Inst::Atom { predicate, args }
265 if matches!(
266 self.ir.resolve_name(predicate),
267 ":lt" | ":le" | ":gt" | ":ge"
268 | ":time:lt" | ":time:le" | ":time:gt" | ":time:ge"
269 | ":duration:lt" | ":duration:le" | ":duration:gt" | ":duration:ge"
270 ) =>
271 {
272 let cmp_op = match self.ir.resolve_name(predicate) {
273 ":lt" | ":time:lt" | ":duration:lt" => CmpOp::Lt,
274 ":le" | ":time:le" | ":duration:le" => CmpOp::Le,
275 ":gt" | ":time:gt" | ":duration:gt" => CmpOp::Gt,
276 ":ge" | ":time:ge" | ":duration:ge" => CmpOp::Ge,
277 _ => unreachable!(),
278 };
279 if args.len() != 2 {
280 return Err(anyhow!("Comparison predicate requires exactly 2 arguments"));
281 }
282 let body = self.plan_join_sequence(premises, bound_vars, continuation)?;
283 self.with_eval(args[0], |this, left_op| {
284 this.with_eval(args[1], |_this, right_op| {
285 Ok(Op::Filter {
286 cond: Condition::Cmp {
287 op: cmp_op,
288 left: left_op.clone(),
289 right: right_op,
290 },
291 body: Box::new(body),
292 })
293 })
294 })
295 }
296 Inst::Atom { predicate, args }
297 if matches!(
298 self.ir.resolve_name(predicate),
299 ":string:starts_with"
300 | ":string:ends_with"
301 | ":string:contains"
302 | ":match_prefix"
303 ) =>
304 {
305 if args.len() != 2 {
306 return Err(anyhow!(
307 "Built-in predicate requires exactly 2 arguments"
308 ));
309 }
310 let body = self.plan_join_sequence(premises, bound_vars, continuation)?;
311 self.with_eval(args[0], |this, left_op| {
312 this.with_eval(args[1], |_this, right_op| {
313 Ok(Op::Filter {
314 cond: Condition::Call {
315 function: predicate,
316 args: vec![left_op.clone(), right_op],
317 },
318 body: Box::new(body),
319 })
320 })
321 })
322 }
323 Inst::Atom { predicate, args } => {
324 let mut scan_vars = Vec::new();
325 let mut new_bindings = Vec::new();
326
327 let mut index_lookup: Option<(usize, Operand)> = None;
329
330 for (i, arg) in args.iter().enumerate() {
331 let arg_inst = self.ir.get(*arg).clone();
332 match arg_inst {
333 Inst::Var(v) if bound_vars.contains(&v) => {
334 if index_lookup.is_none() {
335 index_lookup = Some((i, Operand::Var(v)));
336 }
337 }
338 Inst::Number(n) => {
339 if index_lookup.is_none() {
340 index_lookup =
341 Some((i, Operand::Const(physical::Constant::Number(n))));
342 }
343 }
344 Inst::String(s) => {
345 if index_lookup.is_none() {
346 index_lookup =
347 Some((i, Operand::Const(physical::Constant::String(s))));
348 }
349 }
350 Inst::Name(n) => {
351 if index_lookup.is_none() {
352 index_lookup =
353 Some((i, Operand::Const(physical::Constant::Name(n))));
354 }
355 }
356 Inst::Float(fl) => {
357 if index_lookup.is_none() {
358 index_lookup =
359 Some((i, Operand::Const(physical::Constant::Float(fl))));
360 }
361 }
362 Inst::Time(t) => {
363 if index_lookup.is_none() {
364 index_lookup =
365 Some((i, Operand::Const(physical::Constant::Time(t))));
366 }
367 }
368 Inst::Duration(d) => {
369 if index_lookup.is_none() {
370 index_lookup =
371 Some((i, Operand::Const(physical::Constant::Duration(d))));
372 }
373 }
374 _ => {}
375 }
376 }
377
378 for arg in &args {
379 if let Inst::Var(v) = self.ir.get(*arg)
380 && !bound_vars.contains(v)
381 {
382 scan_vars.push(*v);
383 new_bindings.push(*v);
384 continue;
385 }
386 let tmp = self.fresh_var("scan");
387 scan_vars.push(tmp);
388 new_bindings.push(tmp);
389 }
390
391 for v in &new_bindings {
392 bound_vars.insert(*v);
393 }
394
395 let body = self.plan_join_sequence(premises, bound_vars, continuation)?;
396 let wrapped_body = self.apply_constraints(&args, &scan_vars, body)?;
397
398 let source = if let Some((col_idx, key)) = index_lookup {
399 DataSource::IndexLookup {
400 relation: predicate,
401 col_idx,
402 key,
403 vars: scan_vars,
404 }
405 } else if Some(predicate) == self.delta_pred {
406 DataSource::ScanDelta {
407 relation: predicate,
408 vars: scan_vars,
409 }
410 } else {
411 DataSource::Scan {
412 relation: predicate,
413 vars: scan_vars,
414 }
415 };
416 Ok(Op::Iterate {
417 source,
418 body: Box::new(wrapped_body),
419 })
420 }
421 Inst::Eq(l, r) => {
422 let body = self.plan_join_sequence(premises, bound_vars, continuation)?;
423 self.wrap_eq_check(l, r, body)
424 }
425 Inst::Ineq(l, r) => {
426 let body = self.plan_join_sequence(premises, bound_vars, continuation)?;
427 self.with_eval(l, |this, left_op| {
428 this.with_eval(r, |_this, right_op| {
429 Ok(Op::Filter {
430 cond: Condition::Cmp {
431 op: CmpOp::Neq,
432 left: left_op.clone(),
433 right: right_op,
434 },
435 body: Box::new(body),
436 })
437 })
438 })
439 }
440 Inst::NegAtom(inner) => {
441 let inner_inst = self.ir.get(inner).clone();
442 if let Inst::Atom { predicate, args } = inner_inst {
443 let body = self.plan_join_sequence(premises, bound_vars, continuation)?;
444 let mut neg_args = Vec::new();
445 for arg in &args {
446 let arg_inst = self.ir.get(*arg).clone();
447 match arg_inst {
448 Inst::Var(v) => neg_args.push(Operand::Var(v)),
449 Inst::Number(n) => {
450 neg_args.push(Operand::Const(physical::Constant::Number(n)))
451 }
452 Inst::String(s) => {
453 neg_args.push(Operand::Const(physical::Constant::String(s)))
454 }
455 Inst::Name(n) => {
456 neg_args.push(Operand::Const(physical::Constant::Name(n)))
457 }
458 Inst::Float(fl) => {
459 neg_args.push(Operand::Const(physical::Constant::Float(fl)))
460 }
461 Inst::Time(t) => {
462 neg_args.push(Operand::Const(physical::Constant::Time(t)))
463 }
464 Inst::Duration(d) => {
465 neg_args.push(Operand::Const(physical::Constant::Duration(d)))
466 }
467 _ => return Err(anyhow!("Complex expression in negated atom")),
468 }
469 }
470 Ok(Op::Filter {
471 cond: Condition::Negation {
472 relation: predicate,
473 args: neg_args,
474 },
475 body: Box::new(body),
476 })
477 } else {
478 Err(anyhow!("NegAtom wraps non-atom"))
479 }
480 }
481 _ => Err(anyhow!("Unsupported premise type: {:?}", inst)),
482 }
483 }
484
485 fn apply_constraints(
486 &mut self,
487 args: &[InstId],
488 scan_vars: &[NameId],
489 mut body: Op,
490 ) -> Result<Op> {
491 for (i, arg) in args.iter().enumerate().rev() {
492 let scan_var = scan_vars[i];
493 let arg_inst = self.ir.get(*arg).clone();
494 match arg_inst {
495 Inst::Var(v) => {
496 if v == scan_var {
497 continue;
498 }
499 body = Op::Filter {
500 cond: Condition::Cmp {
501 op: CmpOp::Eq,
502 left: Operand::Var(scan_var),
503 right: Operand::Var(v),
504 },
505 body: Box::new(body),
506 };
507 }
508 _ => {
509 body = self.wrap_eval_check(*arg, Operand::Var(scan_var), body)?;
510 }
511 }
512 }
513 Ok(body)
514 }
515
516 fn wrap_eq_check(&mut self, l: InstId, r: InstId, body: Op) -> Result<Op> {
517 self.with_eval(l, |this, op_l| {
518 this.with_eval(r, |_this, op_r| {
519 Ok(Op::Filter {
520 cond: Condition::Cmp {
521 op: CmpOp::Eq,
522 left: op_l,
523 right: op_r,
524 },
525 body: Box::new(body),
526 })
527 })
528 })
529 }
530
531 fn wrap_eval_check(&mut self, inst: InstId, target: Operand, body: Op) -> Result<Op> {
532 self.with_eval(inst, |_this, op| {
533 Ok(Op::Filter {
534 cond: Condition::Cmp {
535 op: CmpOp::Eq,
536 left: target,
537 right: op,
538 },
539 body: Box::new(body),
540 })
541 })
542 }
543
544 fn with_eval<F>(&mut self, inst: InstId, f: F) -> Result<Op>
545 where
546 F: FnOnce(&mut Self, Operand) -> Result<Op>,
547 {
548 let i = self.ir.get(inst).clone();
549 match i {
550 Inst::Var(v) => f(self, Operand::Var(v)),
551 Inst::String(s) => f(self, Operand::Const(physical::Constant::String(s))),
552 Inst::Number(n) => f(self, Operand::Const(physical::Constant::Number(n))),
553 Inst::Name(n) => f(self, Operand::Const(physical::Constant::Name(n))),
554 Inst::Float(fl) => f(self, Operand::Const(physical::Constant::Float(fl))),
555 Inst::Time(t) => f(self, Operand::Const(physical::Constant::Time(t))),
556 Inst::Duration(d) => f(self, Operand::Const(physical::Constant::Duration(d))),
557 Inst::ApplyFn { function, args } => self.with_eval_args(
558 &args,
559 0,
560 Vec::new(),
561 Box::new(|this, ops| {
562 let tmp = this.fresh_var("call");
563 let inner = f(this, Operand::Var(tmp))?;
564 Ok(Op::Let {
565 var: tmp,
566 expr: Expr::Call {
567 function,
568 args: ops,
569 },
570 body: Box::new(inner),
571 })
572 }),
573 ),
574 Inst::List(args) => {
576 let fn_name = self.ir.intern_name("fn:list".to_string());
577 self.with_eval_args(
578 &args,
579 0,
580 Vec::new(),
581 Box::new(|this, ops| {
582 let tmp = this.fresh_var("list");
583 let inner = f(this, Operand::Var(tmp))?;
584 Ok(Op::Let {
585 var: tmp,
586 expr: Expr::Call {
587 function: fn_name,
588 args: ops,
589 },
590 body: Box::new(inner),
591 })
592 }),
593 )
594 }
595 Inst::Map { keys, values } => {
596 let mut interleaved = Vec::with_capacity(keys.len() + values.len());
598 for (k, v) in keys.iter().zip(values.iter()) {
599 interleaved.push(*k);
600 interleaved.push(*v);
601 }
602 let fn_name = self.ir.intern_name("fn:map".to_string());
603 self.with_eval_args(
604 &interleaved,
605 0,
606 Vec::new(),
607 Box::new(|this, ops| {
608 let tmp = this.fresh_var("map");
609 let inner = f(this, Operand::Var(tmp))?;
610 Ok(Op::Let {
611 var: tmp,
612 expr: Expr::Call {
613 function: fn_name,
614 args: ops,
615 },
616 body: Box::new(inner),
617 })
618 }),
619 )
620 }
621 Inst::Struct { fields, values } => {
622 let mut interleaved = Vec::with_capacity(fields.len() + values.len());
624 for (field, val) in fields.iter().zip(values.iter()) {
625 let name_inst = self.ir.add_inst(Inst::Name(*field));
627 interleaved.push(name_inst);
628 interleaved.push(*val);
629 }
630 let fn_name = self.ir.intern_name("fn:struct".to_string());
631 self.with_eval_args(
632 &interleaved,
633 0,
634 Vec::new(),
635 Box::new(|this, ops| {
636 let tmp = this.fresh_var("struct");
637 let inner = f(this, Operand::Var(tmp))?;
638 Ok(Op::Let {
639 var: tmp,
640 expr: Expr::Call {
641 function: fn_name,
642 args: ops,
643 },
644 body: Box::new(inner),
645 })
646 }),
647 )
648 }
649 _ => Err(anyhow!("Unsupported expression in evaluation")),
650 }
651 }
652
653 fn inst_to_expr<F>(&mut self, inst: InstId, f: F) -> Result<Op>
654 where
655 F: FnOnce(&mut Self, Expr) -> Result<Op>,
656 {
657 let i = self.ir.get(inst).clone();
658 match i {
659 Inst::ApplyFn { function, args } => self.with_eval_args(
660 &args,
661 0,
662 Vec::new(),
663 Box::new(|this, ops| {
664 f(
665 this,
666 Expr::Call {
667 function,
668 args: ops,
669 },
670 )
671 }),
672 ),
673 _ => self.with_eval(inst, |this, op| f(this, Expr::Value(op))),
674 }
675 }
676
677 fn with_eval_args(
678 &mut self,
679 args: &[InstId],
680 index: usize,
681 mut acc: Vec<Operand>,
682 f: Box<dyn FnOnce(&mut Self, Vec<Operand>) -> Result<Op> + '_>,
683 ) -> Result<Op> {
684 if index >= args.len() {
685 return f(self, acc);
686 }
687 self.with_eval(args[index], |this, op| {
688 acc.push(op);
689 this.with_eval_args(args, index + 1, acc, f)
690 })
691 }
692
693 fn plan_head_insert(
694 &mut self,
695 head: InstId,
696 _bound_vars: &mut FxHashSet<NameId>,
697 ) -> Result<Op> {
698 let inst = self.ir.get(head).clone();
699 if let Inst::Atom { predicate, args } = inst {
700 self.with_eval_args(
701 &args,
702 0,
703 Vec::new(),
704 Box::new(|_this, ops| {
705 Ok(Op::Insert {
706 relation: predicate,
707 args: ops,
708 })
709 }),
710 )
711 } else {
712 Err(anyhow!("Head must be an atom"))
713 }
714 }
715
716 fn get_transform_app_args(&self, t_id: InstId) -> Result<Vec<InstId>> {
717 if let Inst::Transform { app, .. } = self.ir.get(t_id)
718 && let Inst::ApplyFn { args, .. } = self.ir.get(*app)
719 {
720 return Ok(args.clone());
721 }
722 Err(anyhow!("Invalid transform structure"))
723 }
724
725 fn try_parse_aggregate(&mut self, t_id: InstId) -> Result<Option<Aggregate>> {
726 let inst = self.ir.get(t_id).clone();
727 if let Inst::Transform {
728 var: Some(var),
729 app,
730 } = inst
731 && let Inst::ApplyFn { function, args } = self.ir.get(app).clone()
732 {
733 let func_name = self.ir.resolve_name(function);
734 if matches!(
735 func_name,
736 "fn:sum"
737 | "fn:count"
738 | "fn:max"
739 | "fn:min"
740 | "fn:collect"
741 | "fn:float:sum"
742 | "fn:float:max"
743 | "fn:float:min"
744 ) {
745 let mut op_args = Vec::new();
746 for arg in args {
747 let arg_inst = self.ir.get(arg).clone();
748 match arg_inst {
749 Inst::Var(v) => op_args.push(Operand::Var(v)),
750 Inst::Number(n) => {
751 op_args.push(Operand::Const(physical::Constant::Number(n)))
752 }
753 Inst::Float(fl) => {
754 op_args.push(Operand::Const(physical::Constant::Float(fl)))
755 }
756 _ => {
757 return Err(anyhow!(
758 "Complex expressions in aggregates not supported yet"
759 ));
760 }
761 }
762 }
763 return Ok(Some(Aggregate {
764 var,
765 func: function,
766 args: op_args,
767 }));
768 }
769 }
770 Ok(None)
771 }
772}