Skip to main content

mangle_analysis/
planner.rs

1// Copyright 2025 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use anyhow::{Result, anyhow};
16use fxhash::FxHashSet;
17use mangle_ir::physical::{self, Aggregate, CmpOp, Condition, DataSource, Expr, Op, Operand};
18use mangle_ir::{Inst, InstId, Ir, NameId};
19
20pub struct Planner<'a> {
21    ir: &'a mut Ir,
22    delta_pred: Option<NameId>,
23    fresh_counter: usize,
24}
25
26impl<'a> Planner<'a> {
27    pub fn new(ir: &'a mut Ir) -> Self {
28        Self {
29            ir,
30            delta_pred: None,
31            fresh_counter: 0,
32        }
33    }
34
35    pub fn with_delta(mut self, delta_pred: NameId) -> Self {
36        self.delta_pred = Some(delta_pred);
37        self
38    }
39
40    pub fn plan_rule(mut self, rule_id: InstId) -> Result<Op> {
41        let (head, premises, transform) = match self.ir.get(rule_id) {
42            Inst::Rule {
43                head,
44                premises,
45                transform,
46            } => (*head, premises.clone(), transform.clone()),
47            _ => return Err(anyhow!("Not a rule")),
48        };
49
50        // Split transforms into blocks by 'do' statements
51        let blocks = self.split_transforms(transform);
52        let num_blocks = blocks.len();
53
54        let mut ops = Vec::new();
55        let mut current_source: Option<(NameId, Vec<NameId>)> = None;
56        let mut bound_vars = FxHashSet::default();
57
58        for (i, block) in blocks.into_iter().enumerate() {
59            let is_last = i == num_blocks - 1;
60
61            if i == 0 {
62                // Block 0: Premises + Lets
63                if is_last {
64                    // Only one block, no aggregations
65                    let op = self.plan_join_sequence(
66                        premises.clone(),
67                        &mut bound_vars,
68                        |planner, vars| {
69                            planner.plan_transforms_sequence(&block, vars, |p, v| {
70                                p.plan_head_insert(head, v)
71                            })
72                        },
73                    )?;
74                    ops.push(op);
75                } else {
76                    // Materialize to temp
77                    let temp_rel = self.fresh_var("temp_grp");
78                    let mut capture_vars: Vec<NameId> = Vec::new(); // Will be populated by continuation
79
80                    let op = self.plan_join_sequence(
81                        premises.clone(),
82                        &mut bound_vars,
83                        |planner, vars| {
84                            planner.plan_transforms_sequence(&block, vars, |_, v| {
85                                let mut sorted_vars: Vec<NameId> = v.iter().cloned().collect();
86                                sorted_vars.sort();
87                                capture_vars = sorted_vars.clone();
88                                let args =
89                                    sorted_vars.iter().map(|&var| Operand::Var(var)).collect();
90                                Ok(Op::Insert {
91                                    relation: temp_rel,
92                                    args,
93                                })
94                            })
95                        },
96                    )?;
97                    ops.push(op);
98                    current_source = Some((temp_rel, capture_vars));
99                }
100            } else {
101                // Block i > 0: Starts with 'do'
102                let (src_rel, src_vars) = current_source.take().expect("No source for aggregation");
103
104                if is_last {
105                    let op = self.plan_block_k(src_rel, src_vars, &block, |p, v| {
106                        p.plan_head_insert(head, v)
107                    })?;
108                    ops.push(op);
109                } else {
110                    let next_temp = self.fresh_var("temp_grp");
111                    let mut next_vars: Vec<NameId> = Vec::new();
112
113                    let op = self.plan_block_k(src_rel, src_vars, &block, |_, v| {
114                        let mut sorted_vars: Vec<NameId> = v.iter().cloned().collect();
115                        sorted_vars.sort();
116                        next_vars = sorted_vars.clone();
117                        let args = sorted_vars.iter().map(|&var| Operand::Var(var)).collect();
118                        Ok(Op::Insert {
119                            relation: next_temp,
120                            args,
121                        })
122                    })?;
123                    ops.push(op);
124                    current_source = Some((next_temp, next_vars));
125                }
126            }
127        }
128
129        if ops.len() == 1 {
130            Ok(ops.remove(0))
131        } else {
132            Ok(Op::Seq(ops))
133        }
134    }
135
136    fn split_transforms(&self, transforms: Vec<InstId>) -> Vec<Vec<InstId>> {
137        let mut blocks = Vec::new();
138        let mut current = Vec::new();
139        for t in transforms {
140            let inst = self.ir.get(t);
141            if let Inst::Transform { var: None, .. } = inst {
142                blocks.push(current);
143                current = Vec::new();
144            }
145            current.push(t);
146        }
147        blocks.push(current);
148        blocks
149    }
150
151    fn plan_block_k<F>(
152        &mut self,
153        source_rel: NameId,
154        source_vars: Vec<NameId>,
155        block: &[InstId],
156        continuation: F,
157    ) -> Result<Op>
158    where
159        F: FnOnce(&mut Self, &mut FxHashSet<NameId>) -> Result<Op>,
160    {
161        let do_stmt = block[0];
162        let rest = &block[1..];
163
164        let keys_insts = self.get_transform_app_args(do_stmt)?;
165        let mut keys = Vec::new();
166        for k in keys_insts {
167            if let Inst::Var(v) = self.ir.get(k) {
168                keys.push(*v);
169            } else {
170                return Err(anyhow!("GroupBy keys must be variables"));
171            }
172        }
173
174        let mut aggregates = Vec::new();
175        let mut lets = Vec::new();
176        for &t in rest {
177            if let Some(agg) = self.try_parse_aggregate(t)? {
178                aggregates.push(agg);
179            } else {
180                lets.push(t);
181            }
182        }
183
184        let mut inner_vars = FxHashSet::default();
185        for &k in &keys {
186            inner_vars.insert(k);
187        }
188        for agg in &aggregates {
189            inner_vars.insert(agg.var);
190        }
191
192        let body = self.plan_transforms_sequence(&lets, &mut inner_vars, continuation)?;
193
194        Ok(Op::GroupBy {
195            source: source_rel,
196            vars: source_vars,
197            keys,
198            aggregates,
199            body: Box::new(body),
200        })
201    }
202
203    fn plan_transforms_sequence<F>(
204        &mut self,
205        transforms: &[InstId],
206        bound_vars: &mut FxHashSet<NameId>,
207        continuation: F,
208    ) -> Result<Op>
209    where
210        F: FnOnce(&mut Self, &mut FxHashSet<NameId>) -> Result<Op>,
211    {
212        if transforms.is_empty() {
213            return continuation(self, bound_vars);
214        }
215
216        let t_id = transforms[0];
217        let rest = &transforms[1..];
218
219        let inst = self.ir.get(t_id).clone();
220        if let Inst::Transform {
221            var: Some(var),
222            app,
223        } = inst
224        {
225            self.inst_to_expr(app, |planner, expr| {
226                bound_vars.insert(var);
227                let body = planner.plan_transforms_sequence(rest, bound_vars, continuation)?;
228                Ok(Op::Let {
229                    var,
230                    expr,
231                    body: Box::new(body),
232                })
233            })
234        } else {
235            // Should not happen if split_transforms is correct
236            Err(anyhow!("Unexpected transform in sequence"))
237        }
238    }
239
240    fn fresh_var(&mut self, prefix: &str) -> NameId {
241        let id = self.fresh_counter;
242        self.fresh_counter += 1;
243        let name = format!("${}_{}", prefix, id);
244        self.ir.intern_name(name)
245    }
246
247    fn plan_join_sequence<F>(
248        &mut self,
249        mut premises: Vec<InstId>,
250        bound_vars: &mut FxHashSet<NameId>,
251        continuation: F,
252    ) -> Result<Op>
253    where
254        F: FnOnce(&mut Self, &mut FxHashSet<NameId>) -> Result<Op>,
255    {
256        if premises.is_empty() {
257            return continuation(self, bound_vars);
258        }
259
260        let current_premise = premises.remove(0);
261        let inst = self.ir.get(current_premise).clone();
262
263        match inst {
264            Inst::Atom { predicate, args } => {
265                let mut scan_vars = Vec::new();
266                let mut new_bindings = Vec::new();
267
268                // Look for a potential index lookup
269                let mut index_lookup: Option<(usize, Operand)> = None;
270
271                for (i, arg) in args.iter().enumerate() {
272                    let arg_inst = self.ir.get(*arg).clone();
273                    match arg_inst {
274                        Inst::Var(v) if bound_vars.contains(&v) => {
275                            if index_lookup.is_none() {
276                                index_lookup = Some((i, Operand::Var(v)));
277                            }
278                        }
279                        Inst::Number(n) => {
280                            if index_lookup.is_none() {
281                                index_lookup =
282                                    Some((i, Operand::Const(physical::Constant::Number(n))));
283                            }
284                        }
285                        Inst::String(s) => {
286                            if index_lookup.is_none() {
287                                index_lookup =
288                                    Some((i, Operand::Const(physical::Constant::String(s))));
289                            }
290                        }
291                        Inst::Name(n) => {
292                            if index_lookup.is_none() {
293                                index_lookup =
294                                    Some((i, Operand::Const(physical::Constant::Name(n))));
295                            }
296                        }
297                        _ => {}
298                    }
299                }
300
301                for arg in &args {
302                    if let Inst::Var(v) = self.ir.get(*arg)
303                        && !bound_vars.contains(v)
304                    {
305                        scan_vars.push(*v);
306                        new_bindings.push(*v);
307                        continue;
308                    }
309                    let tmp = self.fresh_var("scan");
310                    scan_vars.push(tmp);
311                    new_bindings.push(tmp);
312                }
313
314                for v in &new_bindings {
315                    bound_vars.insert(*v);
316                }
317
318                let body = self.plan_join_sequence(premises, bound_vars, continuation)?;
319                let wrapped_body = self.apply_constraints(&args, &scan_vars, body)?;
320
321                let source = if let Some((col_idx, key)) = index_lookup {
322                    DataSource::IndexLookup {
323                        relation: predicate,
324                        col_idx,
325                        key,
326                        vars: scan_vars,
327                    }
328                } else if Some(predicate) == self.delta_pred {
329                    DataSource::ScanDelta {
330                        relation: predicate,
331                        vars: scan_vars,
332                    }
333                } else {
334                    DataSource::Scan {
335                        relation: predicate,
336                        vars: scan_vars,
337                    }
338                };
339                Ok(Op::Iterate {
340                    source,
341                    body: Box::new(wrapped_body),
342                })
343            }
344            Inst::Eq(l, r) => {
345                let body = self.plan_join_sequence(premises, bound_vars, continuation)?;
346                self.wrap_eq_check(l, r, body)
347            }
348            Inst::Ineq(l, r) => {
349                let body = self.plan_join_sequence(premises, bound_vars, continuation)?;
350                self.with_eval(l, |this, left_op| {
351                    this.with_eval(r, |_this, right_op| {
352                        Ok(Op::Filter {
353                            cond: Condition::Cmp {
354                                op: CmpOp::Neq,
355                                left: left_op.clone(),
356                                right: right_op,
357                            },
358                            body: Box::new(body),
359                        })
360                    })
361                })
362            }
363            Inst::NegAtom(inner) => {
364                let inner_inst = self.ir.get(inner).clone();
365                if let Inst::Atom { predicate, args } = inner_inst {
366                    let body = self.plan_join_sequence(premises, bound_vars, continuation)?;
367                    let mut neg_args = Vec::new();
368                    for arg in &args {
369                        let arg_inst = self.ir.get(*arg).clone();
370                        match arg_inst {
371                            Inst::Var(v) => neg_args.push(Operand::Var(v)),
372                            Inst::Number(n) => {
373                                neg_args.push(Operand::Const(physical::Constant::Number(n)))
374                            }
375                            Inst::String(s) => {
376                                neg_args.push(Operand::Const(physical::Constant::String(s)))
377                            }
378                            Inst::Name(n) => {
379                                neg_args.push(Operand::Const(physical::Constant::Name(n)))
380                            }
381                            _ => return Err(anyhow!("Complex expression in negated atom")),
382                        }
383                    }
384                    Ok(Op::Filter {
385                        cond: Condition::Negation {
386                            relation: predicate,
387                            args: neg_args,
388                        },
389                        body: Box::new(body),
390                    })
391                } else {
392                    Err(anyhow!("NegAtom wraps non-atom"))
393                }
394            }
395            _ => Err(anyhow!("Unsupported premise type: {:?}", inst)),
396        }
397    }
398
399    fn apply_constraints(
400        &mut self,
401        args: &[InstId],
402        scan_vars: &[NameId],
403        mut body: Op,
404    ) -> Result<Op> {
405        for (i, arg) in args.iter().enumerate().rev() {
406            let scan_var = scan_vars[i];
407            let arg_inst = self.ir.get(*arg).clone();
408            match arg_inst {
409                Inst::Var(v) => {
410                    if v == scan_var {
411                        continue;
412                    }
413                    body = Op::Filter {
414                        cond: Condition::Cmp {
415                            op: CmpOp::Eq,
416                            left: Operand::Var(scan_var),
417                            right: Operand::Var(v),
418                        },
419                        body: Box::new(body),
420                    };
421                }
422                _ => {
423                    body = self.wrap_eval_check(*arg, Operand::Var(scan_var), body)?;
424                }
425            }
426        }
427        Ok(body)
428    }
429
430    fn wrap_eq_check(&mut self, l: InstId, r: InstId, body: Op) -> Result<Op> {
431        self.with_eval(l, |this, op_l| {
432            this.with_eval(r, |_this, op_r| {
433                Ok(Op::Filter {
434                    cond: Condition::Cmp {
435                        op: CmpOp::Eq,
436                        left: op_l,
437                        right: op_r,
438                    },
439                    body: Box::new(body),
440                })
441            })
442        })
443    }
444
445    fn wrap_eval_check(&mut self, inst: InstId, target: Operand, body: Op) -> Result<Op> {
446        self.with_eval(inst, |_this, op| {
447            Ok(Op::Filter {
448                cond: Condition::Cmp {
449                    op: CmpOp::Eq,
450                    left: target,
451                    right: op,
452                },
453                body: Box::new(body),
454            })
455        })
456    }
457
458    fn with_eval<F>(&mut self, inst: InstId, f: F) -> Result<Op>
459    where
460        F: FnOnce(&mut Self, Operand) -> Result<Op>,
461    {
462        let i = self.ir.get(inst).clone();
463        match i {
464            Inst::Var(v) => f(self, Operand::Var(v)),
465            Inst::String(s) => f(self, Operand::Const(physical::Constant::String(s))),
466            Inst::Number(n) => f(self, Operand::Const(physical::Constant::Number(n))),
467            Inst::Name(n) => f(self, Operand::Const(physical::Constant::Name(n))),
468            Inst::ApplyFn { function, args } => self.with_eval_args(
469                &args,
470                0,
471                Vec::new(),
472                Box::new(|this, ops| {
473                    let tmp = this.fresh_var("call");
474                    let inner = f(this, Operand::Var(tmp))?;
475                    Ok(Op::Let {
476                        var: tmp,
477                        expr: Expr::Call {
478                            function,
479                            args: ops,
480                        },
481                        body: Box::new(inner),
482                    })
483                }),
484            ),
485            _ => Err(anyhow!("Unsupported expression in evaluation")),
486        }
487    }
488
489    fn inst_to_expr<F>(&mut self, inst: InstId, f: F) -> Result<Op>
490    where
491        F: FnOnce(&mut Self, Expr) -> Result<Op>,
492    {
493        let i = self.ir.get(inst).clone();
494        match i {
495            Inst::ApplyFn { function, args } => self.with_eval_args(
496                &args,
497                0,
498                Vec::new(),
499                Box::new(|this, ops| {
500                    f(
501                        this,
502                        Expr::Call {
503                            function,
504                            args: ops,
505                        },
506                    )
507                }),
508            ),
509            _ => self.with_eval(inst, |this, op| f(this, Expr::Value(op))),
510        }
511    }
512
513    fn with_eval_args(
514        &mut self,
515        args: &[InstId],
516        index: usize,
517        mut acc: Vec<Operand>,
518        f: Box<dyn FnOnce(&mut Self, Vec<Operand>) -> Result<Op> + '_>,
519    ) -> Result<Op> {
520        if index >= args.len() {
521            return f(self, acc);
522        }
523        self.with_eval(args[index], |this, op| {
524            acc.push(op);
525            this.with_eval_args(args, index + 1, acc, f)
526        })
527    }
528
529    fn plan_head_insert(
530        &mut self,
531        head: InstId,
532        _bound_vars: &mut FxHashSet<NameId>,
533    ) -> Result<Op> {
534        let inst = self.ir.get(head).clone();
535        if let Inst::Atom { predicate, args } = inst {
536            self.with_eval_args(
537                &args,
538                0,
539                Vec::new(),
540                Box::new(|_this, ops| {
541                    Ok(Op::Insert {
542                        relation: predicate,
543                        args: ops,
544                    })
545                }),
546            )
547        } else {
548            Err(anyhow!("Head must be an atom"))
549        }
550    }
551
552    fn get_transform_app_args(&self, t_id: InstId) -> Result<Vec<InstId>> {
553        if let Inst::Transform { app, .. } = self.ir.get(t_id)
554            && let Inst::ApplyFn { args, .. } = self.ir.get(*app)
555        {
556            return Ok(args.clone());
557        }
558        Err(anyhow!("Invalid transform structure"))
559    }
560
561    fn try_parse_aggregate(&mut self, t_id: InstId) -> Result<Option<Aggregate>> {
562        let inst = self.ir.get(t_id).clone();
563        if let Inst::Transform {
564            var: Some(var),
565            app,
566        } = inst
567            && let Inst::ApplyFn { function, args } = self.ir.get(app).clone()
568        {
569            let func_name = self.ir.resolve_name(function);
570            if matches!(
571                func_name,
572                "fn:sum" | "fn:count" | "fn:max" | "fn:min" | "fn:collect"
573            ) {
574                let mut op_args = Vec::new();
575                for arg in args {
576                    let arg_inst = self.ir.get(arg).clone();
577                    match arg_inst {
578                        Inst::Var(v) => op_args.push(Operand::Var(v)),
579                        Inst::Number(n) => {
580                            op_args.push(Operand::Const(physical::Constant::Number(n)))
581                        }
582                        _ => {
583                            return Err(anyhow!(
584                                "Complex expressions in aggregates not supported yet"
585                            ));
586                        }
587                    }
588                }
589                return Ok(Some(Aggregate {
590                    var,
591                    func: function,
592                    args: op_args,
593                }));
594            }
595        }
596        Ok(None)
597    }
598}