egglog_core_relations/free_join/plan.rs

use std::{collections::BTreeMap, iter, mem, sync::Arc};

use crate::numeric_id::{DenseIdMap, NumericId};
use fixedbitset::FixedBitSet;
use smallvec::{SmallVec, smallvec};

use crate::{
    common::{HashMap, HashSet, IndexSet},
    offsets::Subset,
    pool::Pooled,
    query::{Atom, Query},
    table_spec::Constraint,
};

use super::{ActionId, AtomId, ColumnId, SubAtom, VarInfo, Variable};

#[derive(Clone, Debug, PartialEq, Eq)]
pub(crate) struct ScanSpec {
    pub to_index: SubAtom,
    // Only yield rows where the given constraints match.
    pub constraints: Vec<Constraint>,
}

#[derive(Clone, Debug, PartialEq, Eq)]
pub(crate) struct SingleScanSpec {
    pub atom: AtomId,
    pub column: ColumnId,
    pub cs: Vec<Constraint>,
}
/// Join headers evaluate constraints on a single atom; they prune the search space before the rest
/// of the join plan is executed.
pub(crate) struct JoinHeader {
    pub atom: AtomId,
    /// We currently aren't using these at all. The plan is to use them to
    /// dedup plan stages later (they also help with debugging).
    #[allow(unused)]
    pub constraints: Pooled<Vec<Constraint>>,
    /// A pre-computed table subset that we can use to filter the table,
    /// given these constraints.
    ///
    /// Why keep the constraints at all? Because we want to use them to
    /// discover common plan nodes across different queries (subsets can be
    /// large).
    pub subset: Subset,
}

impl std::fmt::Debug for JoinHeader {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("JoinHeader")
            .field("atom", &self.atom)
            .field("constraints", &self.constraints)
            .field(
                "subset",
                &format_args!("Subset(size={})", self.subset.size()),
            )
            .finish()
    }
}

impl Clone for JoinHeader {
    fn clone(&self) -> Self {
        JoinHeader {
            atom: self.atom,
            constraints: Pooled::cloned(&self.constraints),
            subset: self.subset.clone(),
        }
    }
}

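/// A single stage of a join plan.
///
/// `Intersect` binds one variable by intersecting the candidate values from
/// each listed scan. `FusedIntersect` binds several variables at once from a
/// single "cover" scan, using keys drawn from the cover's columns to probe the
/// `to_intersect` indexes.
///
/// Illustrative sketch (a hypothetical query, not taken from this crate's
/// tests): for `R(x, y), S(y, z)`, a generic-join-style plan could first emit
/// `Intersect { var: y, scans: [R column 1, S column 0] }` and then bind `x`
/// and `z`; `JoinStage::fuse` may later merge adjacent single-variable stages
/// over the same atom into a `FusedIntersect`.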
#[derive(Debug, Clone)]
pub(crate) enum JoinStage {
    Intersect {
        var: Variable,
        scans: SmallVec<[SingleScanSpec; 3]>,
    },
    FusedIntersect {
        cover: ScanSpec,
        bind: SmallVec<[(ColumnId, Variable); 2]>,
        to_intersect: Vec<(ScanSpec, SmallVec<[ColumnId; 2]>)>,
    },
}

impl JoinStage {
    /// Attempt to fuse two stages into one.
    ///
    /// This operation is very conservative right now: it only fuses
    /// scans that do no filtering whatsoever.
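    ///
    /// Concretely (as the match arms below implement it), two shapes are
    /// fused: folding a constraint-free single-column `Intersect` into an
    /// existing `FusedIntersect` over the same cover atom (when there is
    /// nothing left to intersect against), and merging two single-scan
    /// `Intersect` stages on the same atom into a fresh `FusedIntersect`
    /// that binds both columns.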
    fn fuse(&mut self, other: &JoinStage) -> bool {
        use JoinStage::*;
        match (self, other) {
            (
                FusedIntersect {
                    cover,
                    bind,
                    to_intersect,
                },
                Intersect { var, scans },
            ) if to_intersect.is_empty()
                && scans.len() == 1
                && cover.to_index.atom == scans[0].atom
                && scans[0].cs.is_empty() =>
            {
                let col = scans[0].column;
                bind.push((col, *var));
                cover.to_index.vars.push(col);
                true
            }
            (
                x,
                Intersect {
                    var: var2,
                    scans: scans2,
                },
            ) => {
                // This is all somewhat mangled because of the borrowing rules
                // when we pass &mut self into a tuple.
                let (var1, mut scans1) = if let Intersect {
                    var: var1,
                    scans: scans1,
                } = x
                {
                    if !(scans1.len() == 1
                        && scans2.len() == 1
                        && scans1[0].atom == scans2[0].atom
                        && scans2[0].cs.is_empty())
                    {
                        return false;
                    }
                    (*var1, mem::take(scans1))
                } else {
                    return false;
                };
                let atom = scans1[0].atom;
                let col1 = scans1[0].column;
                let col2 = scans2[0].column;
                *x = FusedIntersect {
                    cover: ScanSpec {
                        to_index: SubAtom {
                            atom,
                            vars: smallvec![col1, col2],
                        },
                        constraints: mem::take(&mut scans1[0].cs),
                    },
                    bind: smallvec![(col1, var1), (col2, *var2)],
                    to_intersect: Default::default(),
                };
                true
            }
            _ => false,
        }
    }
}

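/// A compiled query plan: the atoms referenced by the query plus the staged
/// join instructions to run over them.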
#[derive(Debug, Clone)]
pub(crate) struct Plan {
    pub atoms: Arc<DenseIdMap<AtomId, Atom>>,
    pub stages: JoinStages,
}

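/// The stages that make up a plan: per-atom constraint pre-filters (`header`),
/// the join stages themselves (`instrs`), and the action to run for each match
/// (`actions`).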
#[derive(Debug, Clone)]
pub(crate) struct JoinStages {
    pub header: Vec<JoinHeader>,
    pub instrs: Arc<Vec<JoinStage>>,
    pub actions: ActionId,
}

type VarSet = FixedBitSet;
type AtomSet = FixedBitSet;

/// The algorithm used to produce a join plan.
#[derive(Default, Copy, Clone)]
pub enum PlanStrategy {
    /// Free Join: Iteratively pick the smallest atom as the cover for the next
    /// stage, until all subatoms have been visited.
    PureSize,

    /// Free Join: Pick an approximate minimal set of covers, then order those
    /// covers in increasing order of size.
    ///
    /// This is similar to PureSize but we first limit the potential atoms that
    /// can act as covers so as to minimize the total number of stages in the
    /// plan. This is only an approximate minimum: the problem of finding the
    /// exact minimum ("set cover") is NP-hard.
    MinCover,

    /// Generate a plan for the classic Generic Join algorithm, constraining a
    /// single variable per stage.
    #[default]
    Gj,
}

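/// Build a `Plan` for `query`, using the query's configured `PlanStrategy` and
/// wiring the resulting stages to the query's action.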
pub(crate) fn plan_query(query: Query) -> Plan {
    Planner::new(&query.var_info, &query.atoms).plan(query.plan_strategy, query.action)
}

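/// Scratch state threaded through planning.
///
/// `used` tracks variables already bound by earlier stages, `constrained`
/// tracks atoms whose "slow" constraints have already been attached to some
/// scan, and `scratch_subatom` is reusable scratch space for grouping filter
/// subatoms by atom while building a stage.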
struct Planner<'a> {
    // immutable
    vars: &'a DenseIdMap<Variable, VarInfo>,
    atoms: &'a DenseIdMap<AtomId, Atom>,

    // mutable
    used: VarSet,
    constrained: AtomSet,

    scratch_subatom: HashMap<AtomId, SmallVec<[ColumnId; 2]>>,
}

/// StageInfo is an intermediate stage used to describe the ordering of
/// operations. One of these contains enough information to "expand" it to a
/// JoinStage, but it still contains variable information.
///
/// This separation makes it easier for us to iterate with different planning
/// algorithms while sharing the same "backend" that generates a concrete plan.
struct StageInfo {
    cover: SubAtom,
    vars: SmallVec<[Variable; 1]>,
    filters: Vec<(
        SubAtom,                 /* the subatom to index */
        SmallVec<[ColumnId; 2]>, /* how to build a key for that index from the cover atom */
    )>,
}

impl<'a> Planner<'a> {
    pub(crate) fn new(
        vars: &'a DenseIdMap<Variable, VarInfo>,
        atoms: &'a DenseIdMap<AtomId, Atom>,
    ) -> Self {
        Planner {
            vars,
            atoms,
            used: VarSet::with_capacity(vars.n_ids()),
            constrained: AtomSet::with_capacity(atoms.n_ids()),
            scratch_subatom: Default::default(),
        }
    }

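    /// Plan the query using the Free Join strategies: order candidate cover
    /// atoms by (approximate) size, optionally restricting them to a greedy
    /// approximation of a minimal set cover (`MinCover`), then emit one stage
    /// per cover until every variable is bound.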
    fn plan_free_join(
        &mut self,
        strat: PlanStrategy,
        remaining_constraints: &DenseIdMap<AtomId, (usize, &Pooled<Vec<Constraint>>)>,
        stages: &mut Vec<JoinStage>,
    ) {
        let mut size_info = Vec::<(AtomId, usize)>::new();
        match strat {
            PlanStrategy::PureSize => {
                for (atom, (size, _)) in remaining_constraints.iter() {
                    size_info.push((atom, *size));
                }
            }
            PlanStrategy::MinCover => {
                let mut eligible_covers = HashSet::default();
                let mut queue = BucketQueue::new(self.vars, self.atoms);
                while let Some(atom) = queue.pop_min() {
                    eligible_covers.insert(atom);
                }
                for (atom, (size, _)) in remaining_constraints
                    .iter()
                    .filter(|(atom, _)| eligible_covers.contains(atom))
                {
                    size_info.push((atom, *size));
                }
            }
            PlanStrategy::Gj => unreachable!(),
        };
        size_info.sort_by_key(|(_, size)| *size);
        let mut atoms = size_info.iter().map(|(atom, _)| *atom);
        while let Some(info) = self.get_next_freejoin_stage(&mut atoms) {
            stages.push(self.compile_stage(info))
        }
    }

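    /// Plan the query in the style of Generic Join: emit one stage per
    /// variable, ordered by the size of the smallest atom mentioning that
    /// variable (ties broken in favor of more occurrences), fusing adjacent
    /// stages where possible.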
    fn plan_gj(
        &mut self,
        remaining_constraints: &DenseIdMap<AtomId, (usize, &Pooled<Vec<Constraint>>)>,
        stages: &mut Vec<JoinStage>,
    ) {
        // First, map all variables to the size of the smallest atom in which they appear:
        let mut min_sizes = Vec::with_capacity(self.vars.n_ids());
        let mut atoms_hit = AtomSet::with_capacity(self.atoms.n_ids());
        for (var, var_info) in self.vars.iter() {
            let n_occs = var_info.occurrences.len();
            if n_occs == 1 && !var_info.used_in_rhs {
                // Skip planning this variable for now; the follow-up loop
                // below handles the one exception.
                continue;
            }
            if let Some(min_size) = var_info
                .occurrences
                .iter()
                .map(|subatom| {
                    atoms_hit.set(subatom.atom.index(), true);
                    remaining_constraints[subatom.atom].0
                })
                .min()
            {
                min_sizes.push((var, min_size, n_occs));
            }
            // If the variable has no occurrences, it may be bound on the RHS of a
            // rule (or it may just be unused). Either way, we will ignore it when
            // planning the query.
        }
        for (var, var_info) in self.vars.iter() {
            if var_info.occurrences.len() == 1 && !var_info.used_in_rhs {
                // We skipped this variable the first time around because it
                // looks "unused". If it belongs to an atom that otherwise has
                // gone unmentioned, though, we need to plan it anyway.
                let atom = var_info.occurrences[0].atom;
                if !atoms_hit.contains(atom.index()) {
                    min_sizes.push((var, remaining_constraints[atom].0, 1));
                }
            }
        }
        // Sort ascending by size, then descending by number of occurrences.
        min_sizes.sort_by_key(|(_, size, occs)| (*size, -(*occs as i64)));
        for (var, _, _) in min_sizes {
            let occ = self.vars[var].occurrences[0].clone();
            let mut info = StageInfo {
                cover: occ,
                vars: smallvec![var],
                filters: Default::default(),
            };
            for occ in &self.vars[var].occurrences[1..] {
                info.filters
                    .push((occ.clone(), smallvec![ColumnId::new(0)]));
            }
            let next_stage = self.compile_stage(info);
            if let Some(prev) = stages.last_mut() {
                if prev.fuse(&next_stage) {
                    continue;
                }
            }
            stages.push(next_stage);
        }
    }

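    /// Top-level planning entry point: emit a `JoinHeader` for every atom with
    /// "fast" constraints, then build the join stages with the requested
    /// strategy and package everything into a `Plan` that runs `actions` for
    /// each match.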
    pub(crate) fn plan(&mut self, strat: PlanStrategy, actions: ActionId) -> Plan {
        let mut instrs = Vec::new();
        let mut header = Vec::new();
        self.used.clear();
        self.constrained.clear();
        let mut remaining_constraints: DenseIdMap<AtomId, (usize, &Pooled<Vec<Constraint>>)> =
            Default::default();
        // First, plan all the constants:
        for (atom, atom_info) in self.atoms.iter() {
            remaining_constraints.insert(
                atom,
                (
                    atom_info.constraints.approx_size(),
                    &atom_info.constraints.slow,
                ),
            );
            if atom_info.constraints.fast.is_empty() {
                continue;
            }
            header.push(JoinHeader {
                atom,
                constraints: Pooled::cloned(&atom_info.constraints.fast),
                subset: atom_info.constraints.subset.clone(),
            });
        }
        match strat {
            PlanStrategy::PureSize | PlanStrategy::MinCover => {
                self.plan_free_join(strat, &remaining_constraints, &mut instrs);
            }
            PlanStrategy::Gj => {
                self.plan_gj(&remaining_constraints, &mut instrs);
            }
        }
        Plan {
            atoms: self.atoms.clone().into(),
            stages: JoinStages {
                header,
                instrs: Arc::new(instrs),
                actions,
            },
        }
    }

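    /// Pull atoms off `ordering` until one still has at least one unbound
    /// variable; that atom becomes the next stage's cover. All of its unbound
    /// variables are marked used, and any other atom sharing those variables
    /// contributes a filter subatom (plus the key columns, relative to the
    /// cover, used to probe it).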
    fn get_next_freejoin_stage(
        &mut self,
        ordering: &mut impl Iterator<Item = AtomId>,
    ) -> Option<StageInfo> {
        loop {
            let mut covered = false;
            let mut filters = Vec::new();
            let atom = ordering.next()?;
            let atom_info = &self.atoms[atom];
            let mut cover = SubAtom::new(atom);
            let mut vars = SmallVec::<[Variable; 1]>::new();
            for (ix, var) in atom_info.column_to_var.iter() {
                if self.used.contains(var.index()) {
                    continue;
                }
                // This atom is not completely covered by previous stages.
                covered = true;
                self.used.insert(var.index());
                vars.push(*var);
                cover.vars.push(ix);
                for subatom in self.vars[*var].occurrences.iter() {
                    if subatom.atom == atom {
                        continue;
                    }
                    self.scratch_subatom
                        .entry(subatom.atom)
                        .or_default()
                        .extend(subatom.vars.iter().copied());
                }
            }
            if !covered {
                // Try the next atom.
                continue;
            }
            for (atom, cols) in self.scratch_subatom.drain() {
                let mut form_key = SmallVec::<[ColumnId; 2]>::new();
                for var_ix in &cols {
                    let var = self.atoms[atom].column_to_var[*var_ix];
                    // form_key is an index _into the subatom forming the cover_.
                    let cover_col = vars
                        .iter()
                        .enumerate()
                        .find(|(_, v)| **v == var)
                        .map(|(ix, _)| ix)
                        .unwrap();
                    form_key.push(ColumnId::from_usize(cover_col));
                }
                filters.push((SubAtom { atom, vars: cols }, form_key));
            }
            return Some(StageInfo {
                cover,
                vars,
                filters,
            });
        }
    }

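    /// Lower a `StageInfo` into a concrete `JoinStage`: single-variable stages
    /// become `Intersect`, multi-variable stages become `FusedIntersect`. The
    /// first time an atom is scanned, its remaining "slow" constraints are
    /// attached to that scan (tracked via `constrained`).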
    fn compile_stage(
        &mut self,
        StageInfo {
            cover,
            vars,
            filters,
        }: StageInfo,
    ) -> JoinStage {
        if vars.len() == 1 {
            debug_assert!(
                filters
                    .iter()
                    .all(|(_, x)| x.len() == 1 && x[0] == ColumnId::new(0)),
                "filters={filters:?}"
            );
            let scans = SmallVec::<[SingleScanSpec; 3]>::from_iter(
                iter::once(&cover)
                    .chain(filters.iter().map(|(x, _)| x))
                    .map(|subatom| {
                        let atom = subatom.atom;
                        SingleScanSpec {
                            atom,
                            column: subatom.vars[0],
                            cs: if !self.constrained.put(atom.index()) {
                                self.atoms[atom].constraints.slow.clone()
                            } else {
                                Default::default()
                            },
                        }
                    }),
            );
            return JoinStage::Intersect {
                var: vars[0],
                scans,
            };
        }
        let atom = cover.atom;
        let cover = ScanSpec {
            to_index: cover,
            constraints: if !self.constrained.put(atom.index()) {
                self.atoms[atom].constraints.slow.clone()
            } else {
                Default::default()
            },
        };
        let mut bind = SmallVec::new();
        let var_set = &self.atoms[atom].var_to_column;
        for var in vars {
            bind.push((var_set[&var], var));
        }

        let mut to_intersect = Vec::with_capacity(filters.len());
        for (subatom, key_spec) in filters {
            let atom = subatom.atom;
            let scan = ScanSpec {
                to_index: subatom,
                constraints: if !self.constrained.put(atom.index()) {
                    self.atoms[atom].constraints.slow.clone()
                } else {
                    Default::default()
                },
            };
            to_intersect.push((scan, key_spec));
        }

        JoinStage::FusedIntersect {
            cover,
            bind,
            to_intersect,
        }
    }
}

/// Data structure used to greedily solve the set cover problem for a given
/// free join plan.
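///
/// `cover` holds the variables covered so far, `atom_info` maps each atom to
/// its not-yet-covered variables, and `sizes` buckets atoms by how many
/// uncovered variables they still contain.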
struct BucketQueue<'a> {
    var_info: &'a DenseIdMap<Variable, VarInfo>,
    cover: VarSet,
    atom_info: DenseIdMap<AtomId, VarSet>,
    sizes: BTreeMap<usize, IndexSet<AtomId>>,
}

impl<'a> BucketQueue<'a> {
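    /// Build a queue over all atoms, bucketing each atom by the number of
    /// distinct variables it mentions (initially, none of them are covered).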
    fn new(var_info: &'a DenseIdMap<Variable, VarInfo>, atoms: &DenseIdMap<AtomId, Atom>) -> Self {
        let cover = VarSet::with_capacity(var_info.n_ids());
        let mut atom_info = DenseIdMap::with_capacity(atoms.n_ids());
        let mut sizes = BTreeMap::<usize, IndexSet<AtomId>>::new();
        for (id, atom) in atoms.iter() {
            let mut bitset = VarSet::with_capacity(var_info.n_ids());
            for (_, var) in atom.column_to_var.iter() {
                bitset.insert(var.index());
            }
            sizes.entry(bitset.count_ones(..)).or_default().insert(id);
            atom_info.insert(id, bitset);
        }
        BucketQueue {
            var_info,
            cover,
            atom_info,
            sizes,
        }
    }

    /// Return the atom with the largest number of uncovered variables. A
    /// variable is "covered" if a previous call to `pop_min` returned an atom
    /// referencing that variable.
    fn pop_min(&mut self) -> Option<AtomId> {
        // Pick an arbitrary atom from the bucket with the most uncovered variables.
        let (_, atoms) = self.sizes.iter_mut().next_back()?;
        let res = atoms.pop().unwrap();
        let vars = self.atom_info[res].clone();
        // For each variable that we added to the cover, remove it from the
        // entries in atom_info referencing it and update `sizes` to reflect the
        // new ordering.
        for new_var in vars.difference(&self.cover).map(Variable::from_usize) {
            for subatom in &self.var_info[new_var].occurrences {
                let cur_set = &mut self.atom_info[subatom.atom];
                let old_size = cur_set.count_ones(..);
                cur_set.difference_with(&vars);
                let new_size = cur_set.count_ones(..);
                if old_size == new_size {
                    continue;
                }
                if let Some(old_size_set) = self.sizes.get_mut(&old_size) {
                    old_size_set.swap_remove(&subatom.atom);
                    if old_size_set.is_empty() {
                        self.sizes.remove(&old_size);
                    }
                }
                if new_size > 0 {
                    self.sizes.entry(new_size).or_default().insert(subatom.atom);
                }
            }
        }
        self.cover.union_with(&vars);
        Some(res)
    }
}