reddb_server/storage/query/planner/join_dp.rs
//! Join reordering via dynamic programming — Phase 5 P7 building
//! block.
3//!
4//! Implements the classic Selinger-style DP join ordering
5//! algorithm: enumerate every possible left-deep join tree
6//! over `n` relations using a `2^n`-sized DP table, score each
7//! candidate via the cost model, and return the cheapest plan.
8//!
9//! Mirrors PG's `joinrels.c::join_search_one_level` modulo:
10//!
11//! - **No bushy trees**: PG emits both left-deep and bushy join
12//! trees and picks the best. We only emit left-deep for
13//! simplicity — bushy adds another `2^n` to the search space.
14//! - **No worst-case fallback**: PG falls back to the GEQO
15//! genetic algorithm when n > 12 (the DP table grows past
16//! ~4 K entries). We hard-cap at `MAX_DP_RELATIONS = 10` and
17//! refuse to plan larger joins via this module — the caller
18//! keeps the legacy heuristic order for those.
19//! - **No correlation tracking**: PG threads `RelOptInfo` through
20//! the algorithm with selectivity, distinct values, and
21//! referenced columns. We just track row count + estimated
22//! rows-after-join, computed via the cost model.
23//!
24//! The module is **not yet wired** into the planner. Wiring
25//! plugs into `optimizer.rs::reorder_join_inputs` once that
26//! function exists; for now this module is callable from tests
27//! and benchmarks only.
28//!
29//! ## Algorithm
30//!
31//! Given relations `R0, R1, …, Rn-1` and a join graph that
32//! says which pairs share a join predicate:
33//!
34//! 1. Initialize `dp[singleton(i)] = (R_i.row_count, [i])` for
35//! every relation — base case is each relation by itself.
36//! 2. For each subset `S` of size 2..n, ordered by ascending
37//! size:
38//! a. For every way to split `S = L ∪ R` with L and R
39//! non-empty and `dp[L]` and `dp[R]` already computed:
40//! - Verify the split has at least one join predicate
41//! (the join graph). Cartesian products without a
42//! predicate are filtered out unless every other
43//! split is also missing one (handle disconnected
44//! graphs gracefully).
45//! - Compute the cost of joining `L` with `R` using
46//! the cost model.
47//! - If this beats `dp[S]`, replace.
48//! 3. The final answer is `dp[full_set]`.
49//!
50//! Complexity: `O(3^n)` time, `O(2^n)` space. Acceptable for
51//! n ≤ 10 (59,049 candidates, ~1024 DP entries).
52
53use std::collections::HashMap;
54
/// Index into the caller's relation list. The DP works with
/// abstract `RelId`s so the planner can plug arbitrary backing
/// types (tables, subqueries, materialised CTEs).
///
/// Values may be sparse or out of order; `reorder` re-keys them
/// to dense bit positions internally.
pub type RelId = u8;

/// Bitmask over `MAX_DP_RELATIONS` relations. Bit `i` set means
/// relation `i` is present in the subset.
///
/// `i` is the relation's position after `reorder` sorts its
/// input by `RelId`, not the `RelId` value itself. `u16` is wide
/// enough for the 10 bits the current cap needs.
type RelMask = u16;

/// Hard cap on the number of relations this DP can handle. Joins
/// over more relations fall back to the legacy heuristic
/// reorderer in `optimizer.rs`.
///
/// `reorder` returns `DpError::TooManyRelations` when the input
/// exceeds this value.
pub const MAX_DP_RELATIONS: usize = 10;

/// Estimated cardinality (number of rows) after some operation.
pub type Cardinality = f64;

/// Estimated cost (CPU + I/O units) of executing some operation.
pub type Cost = f64;
74
/// Per-relation statistics fed into the DP. The caller computes
/// these from the existing `CostEstimator` / `StatsProvider`
/// pipeline before invoking `reorder`.
#[derive(Debug, Clone, Copy)]
pub struct RelStats {
    /// Caller-chosen identifier. `reorder` matches `JoinEdge`
    /// endpoints against it and echoes it back in
    /// `DpEntry::order`.
    pub id: RelId,
    /// Estimated number of rows in the relation. Seeds both the
    /// `rows` and the `cost` of the relation's single-relation
    /// DP entry.
    pub row_count: Cardinality,
}
83
/// One edge in the join graph — a predicate that constrains two
/// relations to satisfy `left.col = right.col` (or any other
/// equi-join shape). The DP only cares that the edge exists,
/// not the specific columns; selectivity is supplied separately.
///
/// `reorder` treats edges as undirected (it records both
/// `(left, right)` and `(right, left)`), silently skips edges
/// whose endpoints are not in the relation list, and lets a
/// later edge between the same pair overwrite an earlier one.
#[derive(Debug, Clone, Copy)]
pub struct JoinEdge {
    /// One endpoint of the predicate, as a caller `RelId`.
    pub left: RelId,
    /// The other endpoint of the predicate, as a caller `RelId`.
    pub right: RelId,
    /// Selectivity in `[0, 1]` — fraction of the Cartesian
    /// product surviving the join. 0.01 is a typical equi-join
    /// on a non-unique column; 0.0001 means very selective.
    /// The value is used as given; it is not range-checked.
    pub selectivity: f64,
}
97
/// One DP cell: the cheapest plan for joining a specific subset.
///
/// `reorder` returns the cell covering every input relation.
#[derive(Debug, Clone)]
pub struct DpEntry {
    /// Subset of relations covered by this entry, as a bitmask.
    /// Bits index positions in the `RelId`-sorted input, not raw
    /// `RelId` values.
    pub mask: RelMask,
    /// Estimated rows produced by this join order.
    pub rows: Cardinality,
    /// Estimated cost to materialise these rows. For a
    /// single-relation entry this equals the relation's row
    /// count; for a join it is the sum of both inputs' costs
    /// plus the cost of the join itself.
    pub cost: Cost,
    /// The join order as a left-deep sequence of RelId values.
    pub order: Vec<RelId>,
}
110
/// Failure modes reported by the join-order DP.
#[derive(Debug)]
pub enum DpError {
    /// The relation list is longer than the DP is willing to
    /// handle. `count` is the number supplied, `max` the cap.
    TooManyRelations { count: usize, max: usize },
    /// No plan covering every relation could be built because
    /// some required split has no join predicate across it. The
    /// caller should fall back to an ordering that tolerates
    /// Cartesian products.
    Disconnected,
    /// The relation list was empty.
    Empty,
}

impl std::fmt::Display for DpError {
    /// Writes a one-line, human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Empty => f.write_str("join DP requires at least one relation"),
            Self::Disconnected => f.write_str("join graph is disconnected"),
            Self::TooManyRelations { count, max } => write!(
                f,
                "join DP only supports up to {max} relations, got {count}"
            ),
        }
    }
}

impl std::error::Error for DpError {}
140
141/// Find the cheapest left-deep join order for the given
142/// relations and join graph. Returns the final DP entry whose
143/// `order` field contains the relation IDs in execution order
144/// (leftmost = build / outer side, rightmost = probe / inner
145/// side of the last join).
146pub fn reorder(rels: &[RelStats], edges: &[JoinEdge]) -> Result<DpEntry, DpError> {
147 if rels.is_empty() {
148 return Err(DpError::Empty);
149 }
150 if rels.len() > MAX_DP_RELATIONS {
151 return Err(DpError::TooManyRelations {
152 count: rels.len(),
153 max: MAX_DP_RELATIONS,
154 });
155 }
156
157 // Index relations by bit position so RelId values map to
158 // bitmask offsets. The caller's RelId values may be sparse
159 // or out-of-order; we re-key here.
160 let mut by_position: Vec<&RelStats> = rels.iter().collect();
161 by_position.sort_by_key(|r| r.id);
162 let positions: HashMap<RelId, usize> = by_position
163 .iter()
164 .enumerate()
165 .map(|(i, r)| (r.id, i))
166 .collect();
167 let n = by_position.len();
168 let full_mask: RelMask = ((1u32 << n) - 1) as RelMask;
169
170 // Build adjacency lookup: for any pair (l, r) of bit
171 // positions, what's the join selectivity? Missing pair
172 // means no predicate (Cartesian product).
173 let mut adj: HashMap<(usize, usize), f64> = HashMap::new();
174 for edge in edges {
175 let Some(&l) = positions.get(&edge.left) else {
176 continue;
177 };
178 let Some(&r) = positions.get(&edge.right) else {
179 continue;
180 };
181 adj.insert((l, r), edge.selectivity);
182 adj.insert((r, l), edge.selectivity);
183 }
184
185 // DP table: maskmap from RelMask → cheapest DpEntry.
186 let mut dp: HashMap<RelMask, DpEntry> = HashMap::with_capacity(1 << n);
187
188 // Base case: each relation alone.
189 for (i, rel) in by_position.iter().enumerate() {
190 let mask: RelMask = 1 << i;
191 dp.insert(
192 mask,
193 DpEntry {
194 mask,
195 rows: rel.row_count,
196 cost: rel.row_count,
197 order: vec![rel.id],
198 },
199 );
200 }
201
202 // Fill DP for subsets of increasing size.
203 for size in 2..=n {
204 for mask in subsets_of_size(full_mask, size) {
205 let mut best: Option<DpEntry> = None;
206 // Enumerate every non-trivial split of `mask` into
207 // (left, right). Iterating the bits of `mask` and
208 // walking submasks of `mask` yields every legal pair
209 // (left, right) with left ⊆ mask, right = mask ^ left,
210 // both non-empty.
211 let mut left: RelMask = (mask - 1) & mask;
212 while left > 0 {
213 let right: RelMask = mask ^ left;
214 // Avoid duplicate work: only consider left < right
215 // (left-deep DP standard trick).
216 if left < right || left.count_ones() < right.count_ones() {
217 if let (Some(l_entry), Some(r_entry)) = (dp.get(&left), dp.get(&right)) {
218 if let Some(candidate) =
219 cost_join(l_entry, r_entry, &adj, &positions, by_position.as_slice())
220 {
221 match &best {
222 None => best = Some(candidate),
223 Some(prev) if candidate.cost < prev.cost => {
224 best = Some(candidate);
225 }
226 _ => {}
227 }
228 }
229 }
230 }
231 left = (left - 1) & mask;
232 }
233 if let Some(entry) = best {
234 dp.insert(mask, entry);
235 }
236 }
237 }
238
239 dp.remove(&full_mask).ok_or(DpError::Disconnected)
240}
241
242/// Compute the cost of joining two DP entries. Returns `None`
243/// when the join would be a Cartesian product (no predicate
244/// connects any pair across the split) AND the caller's join
245/// graph isn't fully disconnected — we prefer ordered joins
246/// over Cartesian products and only emit the latter when
247/// nothing else is possible.
248fn cost_join(
249 left: &DpEntry,
250 right: &DpEntry,
251 adj: &HashMap<(usize, usize), f64>,
252 positions: &HashMap<RelId, usize>,
253 rels: &[&RelStats],
254) -> Option<DpEntry> {
255 // Find the strongest predicate connecting any pair across
256 // the split. We use min-selectivity (most selective edge)
257 // as the join's effective filter.
258 let left_positions: Vec<usize> = mask_to_positions(left.mask);
259 let right_positions: Vec<usize> = mask_to_positions(right.mask);
260 let mut min_selectivity: Option<f64> = None;
261 for l in &left_positions {
262 for r in &right_positions {
263 if let Some(&sel) = adj.get(&(*l, *r)) {
264 min_selectivity = Some(min_selectivity.map_or(sel, |m| m.min(sel)));
265 }
266 }
267 }
268
269 // Estimate output rows: Cartesian product * predicate
270 // selectivity. Without a predicate, fall back to 1.0 (full
271 // Cartesian) — the planner accepts this only when no
272 // predicate-bearing alternative exists.
273 let selectivity = min_selectivity.unwrap_or(1.0);
274 let out_rows = left.rows * right.rows * selectivity;
275
276 // Cost model: hash-join cost ≈ build_side + probe_side +
277 // output. Pick the smaller side for build (standard hash
278 // join optimisation).
279 let build = left.rows.min(right.rows);
280 let probe = left.rows.max(right.rows);
281 let join_cost = build * 1.5 + probe + out_rows;
282 let total_cost = left.cost + right.cost + join_cost;
283
284 // Build the merged order. Left-deep convention: if `left`
285 // already represents a join chain, append `right`'s
286 // relations to the end.
287 let mut order = left.order.clone();
288 order.extend(&right.order);
289
290 // Skip plain Cartesian products when an alternative exists.
291 // The DP framework calls this for every split; the caller
292 // discards the entry by passing None when min_selectivity is
293 // missing AND the split contains more than one relation on
294 // each side. Single-relation legs are always allowed.
295 if min_selectivity.is_none() && !left_positions.is_empty() && !right_positions.is_empty() {
296 // If both sides are singletons, allow the Cartesian
297 // (the user's join graph might genuinely be empty).
298 // Otherwise reject — the DP will pick a different split.
299 if left_positions.len() > 1 || right_positions.len() > 1 {
300 return None;
301 }
302 }
303
304 let _ = (positions, rels); // currently unused — reserved for selectivity refinement
305
306 Some(DpEntry {
307 mask: left.mask | right.mask,
308 rows: out_rows,
309 cost: total_cost,
310 order,
311 })
312}
313
314/// Yield every subset of `universe` with exactly `k` bits set.
315/// Used by the DP main loop to walk subsets in size order.
316/// Iterative — generates ~`C(n, k)` masks per call.
317fn subsets_of_size(universe: RelMask, k: usize) -> Vec<RelMask> {
318 let n = (universe.count_ones() as usize).max(k);
319 let mut out = Vec::new();
320 for mask in 1..=universe {
321 if (mask as RelMask) & universe == mask as RelMask
322 && (mask as RelMask).count_ones() as usize == k
323 {
324 out.push(mask as RelMask);
325 }
326 let _ = n;
327 }
328 out
329}
330
331/// Convert a bitmask back into a list of bit positions. Used by
332/// `cost_join` to walk pairs across a split.
333fn mask_to_positions(mask: RelMask) -> Vec<usize> {
334 let mut out = Vec::with_capacity(mask.count_ones() as usize);
335 let mut m = mask;
336 let mut pos = 0;
337 while m > 0 {
338 if m & 1 == 1 {
339 out.push(pos);
340 }
341 m >>= 1;
342 pos += 1;
343 }
344 out
345}