use std::collections::HashMap;
/// Identifier of a base relation, as carried by `RelStats::id` and by the endpoints of a
/// `JoinEdge`.
pub type RelId = u8;
/// Bitset of relations used as the DP table key. Bit `i` stands for the relation at position `i`
/// after `reorder` has sorted its inputs by id -- not for the relation whose id is `i`.
type RelMask = u16;
/// Largest number of relations `reorder` accepts; longer inputs are rejected with
/// `DpError::TooManyRelations`.
pub const MAX_DP_RELATIONS: usize = 10;
/// Estimated number of rows.
pub type Cardinality = f64;
/// Estimated plan cost. `reorder` and `cost_join` build it from row counts, so it is expressed
/// in row-count-based units.
pub type Cost = f64;
/// Statistics for one base relation taking part in the join.
#[derive(Debug, Clone, Copy)]
pub struct RelStats {
/// Identifier that `JoinEdge::left` / `JoinEdge::right` refer to.
pub id: RelId,
/// Estimated row count. `reorder` uses it as both the row estimate and the cost of the
/// single-relation plan.
pub row_count: Cardinality,
}
/// A join predicate linking two relations.
///
/// `reorder` registers every edge in both directions, so which endpoint is `left` and which is
/// `right` makes no difference.
#[derive(Debug, Clone, Copy)]
pub struct JoinEdge {
/// Id of one endpoint.
pub left: RelId,
/// Id of the other endpoint.
pub right: RelId,
/// Factor multiplied into the product of the two input row counts to estimate the join's
/// output rows.
pub selectivity: f64,
}
/// A planned (sub-)join: one entry of the DP table, and also the value `reorder` returns.
#[derive(Debug, Clone)]
pub struct DpEntry {
/// Set of relations covered by this plan, as bit positions (see `RelMask`).
pub mask: RelMask,
/// Estimated output row count of the plan.
pub rows: Cardinality,
/// Accumulated cost: the costs of both inputs plus the cost of the join combining them
/// (for a single relation, its row count).
pub cost: Cost,
/// Relation ids in plan order; a join lists its left input's ids followed by its right
/// input's ids.
pub order: Vec<RelId>,
}
/// Reasons `reorder` can fail.
#[derive(Debug)]
pub enum DpError {
/// More relations were supplied than the DP supports: `count` is the number given, `max` the
/// limit that was exceeded.
TooManyRelations { count: usize, max: usize },
/// No plan covering every relation could be built.
Disconnected,
/// The relation list was empty.
Empty,
}
/// User-facing message for each [`DpError`] variant.
impl std::fmt::Display for DpError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // Fixed messages need no formatting machinery.
            Self::Empty => f.write_str("join DP requires at least one relation"),
            Self::Disconnected => f.write_str("join graph is disconnected"),
            // Reports the limit first, then the offending count.
            Self::TooManyRelations { count, max } => f.write_fmt(format_args!(
                "join DP only supports up to {max} relations, got {count}"
            )),
        }
    }
}
// Marker impl: lets `DpError` be used wherever a `std::error::Error` is expected. No trait
// methods are overridden, so they all keep their default behaviour.
impl std::error::Error for DpError {}
/// Picks a join order for `rels` by dynamic programming over subsets of relations.
///
/// Relations are sorted by id and given bit positions in that order. Each single relation seeds
/// the table with `rows == cost == row_count`. Larger subsets are then filled in by increasing
/// size: every subset is split into two non-empty parts that already have table entries, the
/// split is priced with `cost_join`, and the cheapest candidate is kept. On equal cost the
/// candidate encountered first wins. Subsets for which no split is accepted get no entry.
///
/// Edges that name an id not present in `rels` are ignored. Each edge is usable in both
/// directions; if the same pair is listed more than once, the last selectivity wins.
///
/// # Errors
///
/// * [`DpError::Empty`] if `rels` is empty.
/// * [`DpError::TooManyRelations`] if `rels` has more than [`MAX_DP_RELATIONS`] entries.
/// * [`DpError::Disconnected`] if no plan covering every relation was produced.
pub fn reorder(rels: &[RelStats], edges: &[JoinEdge]) -> Result<DpEntry, DpError> {
    match rels.len() {
        0 => return Err(DpError::Empty),
        count if count > MAX_DP_RELATIONS => {
            return Err(DpError::TooManyRelations {
                count,
                max: MAX_DP_RELATIONS,
            });
        }
        _ => {}
    }

    // Bit positions follow id order (stable sort, so equal ids keep their input order).
    let mut sorted: Vec<&RelStats> = rels.iter().collect();
    sorted.sort_by_key(|rel| rel.id);
    let n = sorted.len();

    let mut positions: HashMap<RelId, usize> = HashMap::new();
    for (idx, rel) in sorted.iter().enumerate() {
        positions.insert(rel.id, idx);
    }

    // Computed in u32 so the shift cannot overflow the narrower mask type.
    let full_mask: RelMask = ((1u32 << n) - 1) as RelMask;

    // Selectivity lookup keyed by position pair, stored in both directions.
    let mut adj: HashMap<(usize, usize), f64> = HashMap::new();
    for edge in edges {
        if let (Some(&l), Some(&r)) = (positions.get(&edge.left), positions.get(&edge.right)) {
            adj.insert((l, r), edge.selectivity);
            adj.insert((r, l), edge.selectivity);
        }
    }

    // Base case: one entry per single relation.
    let mut dp: HashMap<RelMask, DpEntry> = HashMap::with_capacity(1 << n);
    dp.extend(sorted.iter().enumerate().map(|(bit, rel)| {
        let mask: RelMask = 1 << bit;
        let entry = DpEntry {
            mask,
            rows: rel.row_count,
            cost: rel.row_count,
            order: vec![rel.id],
        };
        (mask, entry)
    }));

    for size in 2..=n {
        for mask in subsets_of_size(full_mask, size) {
            let mut best: Option<DpEntry> = None;

            // Walk every non-empty proper sub-mask of `mask`, largest value first.
            let mut cursor: RelMask = (mask - 1) & mask;
            while cursor != 0 {
                let left = cursor;
                cursor = (cursor - 1) & mask;
                let right = mask ^ left;

                // Only orientations where `left` is numerically smaller or has fewer relations
                // than `right` are priced.
                if left >= right && left.count_ones() >= right.count_ones() {
                    continue;
                }
                let (Some(l_entry), Some(r_entry)) = (dp.get(&left), dp.get(&right)) else {
                    continue;
                };
                let Some(candidate) = cost_join(l_entry, r_entry, &adj, &positions, &sorted)
                else {
                    continue;
                };

                // Strict comparison: an equally cheap later candidate does not displace the
                // current best.
                let improves = match &best {
                    None => true,
                    Some(prev) => candidate.cost < prev.cost,
                };
                if improves {
                    best = Some(candidate);
                }
            }

            if let Some(entry) = best {
                dp.insert(mask, entry);
            }
        }
    }

    dp.remove(&full_mask).ok_or(DpError::Disconnected)
}
/// Estimates the plan obtained by joining two disjoint sub-plans.
///
/// The two inputs must be linked by at least one edge in `adj`, which is keyed by pairs of bit
/// positions as decoded from the entries' masks. When several edges cross the split, the
/// smallest selectivity among them is used.
///
/// Cost model:
/// * output rows = `left.rows * right.rows * selectivity`
/// * join cost   = `1.5 * smaller input rows + larger input rows + output rows`
/// * total cost  = both inputs' accumulated costs plus the join cost
///
/// The resulting `order` lists `left`'s relations followed by `right`'s.
///
/// Returns `None` when no edge connects the two sides, i.e. when the join would be a cross
/// product. This applies to single-relation inputs as well: if a cross product between two base
/// relations were admitted, larger subsets could be assembled on top of it and a disconnected
/// join graph would go unnoticed by the caller.
///
/// `positions` and `rels` are accepted for call-site compatibility; the cost model does not
/// read them.
fn cost_join(
    left: &DpEntry,
    right: &DpEntry,
    adj: &HashMap<(usize, usize), f64>,
    positions: &HashMap<RelId, usize>,
    rels: &[&RelStats],
) -> Option<DpEntry> {
    let _ = (positions, rels);

    let left_positions = mask_to_positions(left.mask);
    let right_positions = mask_to_positions(right.mask);

    // Smallest selectivity over all edges crossing the split. Bail out before doing any further
    // work (cost arithmetic, cloning the order) when there is no such edge.
    let selectivity = left_positions
        .iter()
        .flat_map(|&l| right_positions.iter().map(move |&r| (l, r)))
        .filter_map(|pair| adj.get(&pair).copied())
        .reduce(f64::min)?;

    let out_rows = left.rows * right.rows * selectivity;
    // The smaller input is weighted 1.5x, the larger one 1x.
    let build = left.rows.min(right.rows);
    let probe = left.rows.max(right.rows);
    let join_cost = build * 1.5 + probe + out_rows;

    let mut order = Vec::with_capacity(left.order.len() + right.order.len());
    order.extend_from_slice(&left.order);
    order.extend_from_slice(&right.order);

    Some(DpEntry {
        mask: left.mask | right.mask,
        rows: out_rows,
        cost: left.cost + right.cost + join_cost,
        order,
    })
}
fn subsets_of_size(universe: RelMask, k: usize) -> Vec<RelMask> {
let n = (universe.count_ones() as usize).max(k);
let mut out = Vec::new();
for mask in 1..=universe {
if (mask as RelMask) & universe == mask as RelMask
&& (mask as RelMask).count_ones() as usize == k
{
out.push(mask as RelMask);
}
let _ = n;
}
out
}
/// Returns the indices of the set bits of `mask`, lowest bit first.
fn mask_to_positions(mask: RelMask) -> Vec<usize> {
    let mut positions = Vec::with_capacity(mask.count_ones() as usize);
    let mut remaining = mask;
    while remaining != 0 {
        // The index of the lowest set bit is the number of zero bits below it.
        positions.push(remaining.trailing_zeros() as usize);
        // Clear that bit and move on to the next one.
        remaining &= remaining - 1;
    }
    positions
}