use crate::{Error, Result};
/// A strategy for computing an upper bound ("ceiling") on the size of a join.
///
/// `relations[r]` is the row count of relation `r`; each `(i, j)` in
/// `equality_predicates` names an equality join predicate between relations `i`
/// and `j` (indices into `relations`).
pub trait UpperBound {
    /// Returns the ceiling for joining `relations` under `equality_predicates`.
    fn ceiling(&self, relations: &[u64], equality_predicates: &[(usize, usize)]) -> u64;
}

/// Cross-product bound: the product of all row counts, ignoring predicates.
///
/// Saturates at `u64::MAX`. An empty `relations` slice yields `1` (the empty product).
pub struct ProductBound;

impl UpperBound for ProductBound {
    fn ceiling(&self, relations: &[u64], _eq: &[(usize, usize)]) -> u64 {
        relations.iter().fold(1u64, |acc, &n| acc.saturating_mul(n))
    }
}

/// Bound that divides the cross product by one distinct count per equality predicate.
///
/// For each predicate `(i, j)` the product is divided by `max(d_i, d_j)`, where `d_r`
/// is `distinct_counts[r]`; missing entries and zeros are treated as `1`.
///
/// NOTE(review): this looks like the classic `|R|·|S| / max(V(R,a), V(S,a))`
/// estimation formula, which is not a hard guarantee on skewed data — confirm
/// callers treat it as an estimate-grade ceiling.
pub struct ChainBound {
    /// Per-relation distinct counts, indexed like the `relations` slice passed to `ceiling`.
    pub distinct_counts: Vec<u64>,
}

impl ChainBound {
    /// Creates a `ChainBound` from per-relation distinct counts.
    pub fn new(distinct_counts: Vec<u64>) -> Self {
        Self { distinct_counts }
    }

    /// Divisor for predicate `(i, j)`: the larger of the two distinct counts (at least 1).
    fn divisor(&self, i: usize, j: usize) -> u64 {
        let distinct = |r: usize| self.distinct_counts.get(r).copied().unwrap_or(1).max(1);
        distinct(i).max(distinct(j))
    }
}

impl UpperBound for ChainBound {
    /// Returns `product(relations) / Π max(d_i, d_j)`, rounded **up** so the ceiling
    /// never falls below the formula's real value.
    ///
    /// Returns `0` for no relations or when any relation has zero rows, and the plain
    /// product when there are no predicates. Results saturate at `u64::MAX`.
    fn ceiling(&self, relations: &[u64], equality_predicates: &[(usize, usize)]) -> u64 {
        if relations.is_empty() {
            return 0;
        }
        if equality_predicates.is_empty() {
            return ProductBound.ceiling(relations, &[]);
        }
        // An empty input makes the join empty. Handling it here also keeps the
        // f64 path below free of `inf * 0 = NaN`.
        if relations.contains(&0) {
            return 0;
        }
        let divisors = equality_predicates
            .iter()
            .map(|&(i, j)| self.divisor(i, j));
        // Stay exact while the product fits in u128. Saturating first and dividing
        // afterwards would badly under-estimate: five 2^30-row relations joined by
        // four predicates with 2^30 distinct values give 2^150 / 2^120 = 2^30, but a
        // saturated u128 product divided the same way gives only ~2^8.
        let exact_product = relations
            .iter()
            .try_fold(1u128, |acc, &n| acc.checked_mul(u128::from(n)));
        match exact_product {
            Some(product) => {
                // Ceiling division. ceil(ceil(x / a) / b) == ceil(x / (a * b)) for
                // positive integers, so dividing one predicate at a time is exact.
                let bound = divisors.fold(product, |acc, d| {
                    let d = u128::from(d);
                    acc / d + u128::from(acc % d != 0)
                });
                if bound > u128::from(u64::MAX) {
                    u64::MAX
                } else {
                    bound as u64
                }
            }
            None => {
                // The product exceeds u128; approximate in f64. A float-to-int `as`
                // cast saturates, so an infinite product maps to u64::MAX (a safe,
                // conservative ceiling).
                let product: f64 = relations.iter().map(|&n| n as f64).product();
                let bound = divisors.fold(product, |acc, d| acc / d as f64);
                bound.ceil() as u64
            }
        }
    }
}
/// Coarse ceiling used when predicates are present: `min(product, min_rows * max_rows)`.
///
/// Returns `0` for no relations and the plain product when there are no predicates
/// (the predicates' contents are otherwise not inspected). For exactly two relations
/// `min_rows * max_rows` equals the product, so it can only tighten the result for
/// three or more relations. All multiplications saturate at `u64::MAX`.
///
/// NOTE(review): despite the name this is not the textbook AGM (fractional edge
/// cover) bound; e.g. for a chain R ⋈ S ⋈ T with a small S it can fall below
/// |R|·|T|. Confirm it is only used where that is acceptable.
pub struct AgmBound;
impl UpperBound for AgmBound {
    fn ceiling(&self, relations: &[u64], equality_predicates: &[(usize, usize)]) -> u64 {
        // Nothing to join.
        if relations.is_empty() {
            return 0;
        }
        // No predicates: the result is a cross product.
        if equality_predicates.is_empty() {
            return ProductBound.ceiling(relations, &[]);
        }
        // One pass collecting the smallest, the largest and the saturating product.
        let (smallest, largest, product) = relations.iter().fold(
            (u64::MAX, 0u64, 1u64),
            |(lo, hi, prod), &rows| (lo.min(rows), hi.max(rows), prod.saturating_mul(rows)),
        );
        let extremes = smallest.saturating_mul(largest);
        if extremes < product {
            extremes
        } else {
            product
        }
    }
}
pub fn clamp_estimate(estimate: f64, ceiling: u64) -> Result<u64> {
let clamped = estimate.max(0.0).min(u64::MAX as f64) as u64;
if clamped <= ceiling {
Ok(clamped)
} else {
Err(Error::LpBoundExceeded {
estimate,
ceiling: ceiling as f64,
})
}
}
/// Converts `estimate` to a row count and caps it at `ceiling` instead of failing.
///
/// Negative and NaN estimates map to `0`, values above `u64::MAX` saturate, and any
/// fractional part is truncated toward zero.
pub fn saturating_clamp(estimate: f64, ceiling: u64) -> u64 {
    let rows = estimate.max(0.0).min(u64::MAX as f64) as u64;
    std::cmp::min(rows, ceiling)
}
/// Join-size ceiling computed by solving a small linear program for each connected
/// component of the predicate graph (available with the `lp_solver` feature).
///
/// The derived `Default` yields the same value as [`LpJoinBound::new`]: no distinct
/// counts.
#[cfg(feature = "lp_solver")]
#[derive(Default)]
pub struct LpJoinBound {
    /// Optional per-relation distinct counts, indexed like the `relations` slice;
    /// consulted only by `ceiling_with_distinct`.
    distinct_counts: Vec<u64>,
}
#[cfg(feature = "lp_solver")]
impl LpJoinBound {
    /// Creates a bound with no distinct counts.
    pub fn new() -> Self {
        Self {
            distinct_counts: Vec::new(),
        }
    }

    /// Creates a bound whose [`Self::ceiling_with_distinct`] may use
    /// `distinct_counts[r]` for relation `r`.
    pub fn with_distinct_counts(distinct_counts: Vec<u64>) -> Self {
        Self { distinct_counts }
    }

    /// LP-based ceiling that weights each relation by its row count only.
    ///
    /// Returns `0` for no relations or when any relation is empty, and the plain
    /// product when no usable predicate is given.
    pub fn ceiling(&self, relations: &[u64], equality_predicates: &[(usize, usize)]) -> u64 {
        self.solve(relations, equality_predicates, false)
    }

    /// Like [`Self::ceiling`], but a relation with a known positive distinct count `d`
    /// is weighted by `min(rows, d)` instead of `rows`.
    pub fn ceiling_with_distinct(
        &self,
        relations: &[u64],
        equality_predicates: &[(usize, usize)],
    ) -> u64 {
        self.solve(relations, equality_predicates, true)
    }

    /// Shared driver for both public ceilings.
    ///
    /// Drops invalid predicates and splits the relations into connected components of
    /// the remaining predicate graph. It then multiplies the per-component bounds,
    /// saturating at `u64::MAX`.
    fn solve(
        &self,
        relations: &[u64],
        equality_predicates: &[(usize, usize)],
        use_distinct: bool,
    ) -> u64 {
        if relations.is_empty() {
            return 0;
        }
        if equality_predicates.is_empty() {
            return ProductBound.ceiling(relations, &[]);
        }
        // Every relation takes part in the result (disconnected components are
        // cross-multiplied below), so one empty relation makes the join empty.
        // The LP cannot see this: `solve_component` gives rows <= 1 a zero weight
        // and floors its result at 1. Checking here also guarantees every component
        // bound below is >= 1, which makes the early `u64::MAX` exit independent of
        // component order.
        if relations.contains(&0) {
            return 0;
        }
        let n = relations.len();
        // Ignore predicates that reference unknown relations or a relation with itself.
        let preds: Vec<(usize, usize)> = equality_predicates
            .iter()
            .copied()
            .filter(|&(i, j)| i < n && j < n && i != j)
            .collect();
        if preds.is_empty() {
            return ProductBound.ceiling(relations, &[]);
        }
        let components = connected_components(n, &preds);
        let mut total: u128 = 1;
        for comp in &components {
            let ceil = self.solve_component(relations, &preds, comp, use_distinct);
            total = total.saturating_mul(ceil as u128);
            // Safe to stop: no remaining component can contribute a factor below 1.
            if total >= u64::MAX as u128 {
                return u64::MAX;
            }
        }
        total as u64
    }

    /// Bound for one connected component of the predicate graph.
    ///
    /// Solves: minimise `Σ_r w_r · x_r` subject to `x_i + x_j >= 1` for every predicate
    /// `(i, j)` inside the component and `x_r >= 0`. Here `w_r = ln(size_r)` (0 when
    /// `size_r <= 1`), and `size_r` is the row count, optionally capped by a positive
    /// distinct count.
    ///
    /// The bound is `exp(optimum)`, floored at 1 and rounded up. Values within a tiny
    /// epsilon of an integer snap to that integer. It falls back to
    /// [`Self::fallback`] if the solver fails or the result is not finite.
    fn solve_component(
        &self,
        relations: &[u64],
        all_predicates: &[(usize, usize)],
        component: &[usize],
        use_distinct: bool,
    ) -> u64 {
        // A lone relation joins with nothing: its bound is its row count.
        if component.len() == 1 {
            return relations[component[0]];
        }
        let in_comp: std::collections::HashSet<usize> = component.iter().copied().collect();
        let comp_preds: Vec<(usize, usize)> = all_predicates
            .iter()
            .copied()
            .filter(|&(i, j)| in_comp.contains(&i) && in_comp.contains(&j))
            .collect();
        use good_lp::{
            Expression, ProblemVariables, Solution, SolverModel, default_solver, variable,
        };
        let mut vars = ProblemVariables::new();
        let mut var_for: std::collections::HashMap<usize, good_lp::Variable> =
            std::collections::HashMap::with_capacity(component.len());
        let mut objective = Expression::with_capacity(component.len());
        for &r in component {
            let v = vars.add(variable().min(0.0));
            var_for.insert(r, v);
            let row_count = relations[r];
            let mut size_f = row_count as f64;
            if use_distinct {
                if let Some(&d) = self.distinct_counts.get(r) {
                    if d > 0 {
                        size_f = size_f.min(d as f64);
                    }
                }
            }
            // ln(size) would be <= 0 (or -inf for 0); such relations cost nothing.
            let coef = if size_f <= 1.0 { 0.0 } else { size_f.ln() };
            objective.add_mul(coef, v);
        }
        let mut model = vars.minimise(&objective).using(default_solver);
        for &(i, j) in &comp_preds {
            let xi = var_for[&i];
            let xj = var_for[&j];
            let lhs: Expression = xi + xj;
            model = model.with(lhs.geq(1.0));
        }
        match model.solve() {
            Ok(sol) => {
                let lp_min = sol.eval(&objective);
                let raw = lp_min.exp();
                if !raw.is_finite() || raw < 0.0 {
                    return self.fallback(relations, &comp_preds, component);
                }
                let raw = raw.max(1.0);
                if raw >= u64::MAX as f64 {
                    u64::MAX
                } else {
                    // Snap near-integers so solver noise (e.g. 999.9999999 or
                    // 1000.0000001) does not push the bound to the next integer.
                    let rounded = raw.round();
                    let snap_eps = 1e-9_f64.max(raw.abs() * 1e-12);
                    if (raw - rounded).abs() <= snap_eps {
                        rounded as u64
                    } else {
                        raw.ceil() as u64
                    }
                }
            }
            Err(_) => self.fallback(relations, &comp_preds, component),
        }
    }

    /// Solver-free bound for a component.
    ///
    /// Returns the product of its row counts when it has no predicates. Otherwise it
    /// returns `AgmBound` over its row counts; the `(0, 1)` predicate only selects
    /// `AgmBound`'s with-predicates branch.
    fn fallback(
        &self,
        relations: &[u64],
        comp_preds: &[(usize, usize)],
        component: &[usize],
    ) -> u64 {
        if comp_preds.is_empty() {
            return component
                .iter()
                .map(|&r| relations[r])
                .fold(1u64, |a, n| a.saturating_mul(n));
        }
        let comp_rows: Vec<u64> = component.iter().map(|&r| relations[r]).collect();
        let agm = AgmBound;
        agm.ceiling(&comp_rows, &[(0, 1)])
    }
}
/// Exposes [`LpJoinBound`] through [`UpperBound`] using row counts only; distinct
/// counts are not consulted on this path.
#[cfg(feature = "lp_solver")]
impl UpperBound for LpJoinBound {
    fn ceiling(&self, relations: &[u64], equality_predicates: &[(usize, usize)]) -> u64 {
        // Same as the inherent `ceiling`: the row-count-only LP driver.
        self.solve(relations, equality_predicates, false)
    }
}
/// Partitions vertices `0..n` into connected components of the undirected graph
/// given by `edges`, using union-find with path halving.
///
/// Edges touching a vertex `>= n` are ignored. Each component is sorted ascending,
/// and components are ordered by their smallest vertex. Isolated vertices form
/// singleton components.
#[cfg(feature = "lp_solver")]
fn connected_components(n: usize, edges: &[(usize, usize)]) -> Vec<Vec<usize>> {
    /// Follows parent links to the representative, pointing each visited node at
    /// its grandparent along the way (path halving).
    fn root_of(links: &mut [usize], start: usize) -> usize {
        let mut node = start;
        loop {
            let up = links[node];
            if up == node {
                return node;
            }
            let grand = links[up];
            links[node] = grand;
            node = grand;
        }
    }

    let mut links: Vec<usize> = (0..n).collect();
    for &(a, b) in edges.iter().filter(|&&(a, b)| a < n && b < n) {
        let (ra, rb) = (root_of(&mut links, a), root_of(&mut links, b));
        if ra != rb {
            links[ra] = rb;
        }
    }

    // Visiting vertices in ascending order means each component is first seen at
    // its smallest member and then filled in ascending order, so the result is
    // already in canonical order without a separate sort.
    let mut slot_of_root: Vec<Option<usize>> = vec![None; n];
    let mut components: Vec<Vec<usize>> = Vec::new();
    for v in 0..n {
        let root = root_of(&mut links, v);
        match slot_of_root[root] {
            Some(idx) => components[idx].push(v),
            None => {
                slot_of_root[root] = Some(components.len());
                components.push(vec![v]);
            }
        }
    }
    components
}
#[cfg(test)]
mod tests {
    // Unit tests for the solver-free bounds and the clamp helpers.
    use super::*;

    #[test]
    fn product_bound_two_relations() {
        let rows = [100u64, 200];
        assert_eq!(ProductBound.ceiling(&rows, &[]), 20_000);
    }

    #[test]
    fn product_bound_overflow_saturates() {
        // u64::MAX * 2 must saturate rather than wrap.
        assert_eq!(ProductBound.ceiling(&[u64::MAX, 2], &[]), u64::MAX);
    }

    #[test]
    fn product_bound_empty_relations() {
        // The empty product is 1.
        assert_eq!(ProductBound.ceiling(&[], &[]), 1);
    }

    #[test]
    fn agm_no_predicates_falls_back_to_product() {
        assert_eq!(AgmBound.ceiling(&[10, 20, 30], &[]), 10 * 20 * 30);
    }

    #[test]
    fn agm_with_predicates_tighter_than_product() {
        let rows = [1_000u64, 1_000_000];
        let agm = AgmBound.ceiling(&rows, &[(0, 1)]);
        assert!(agm <= ProductBound.ceiling(&rows, &[]));
    }

    #[test]
    fn clamp_within_ceiling() {
        assert_eq!(clamp_estimate(500.0, 1000).unwrap(), 500);
    }

    #[test]
    fn clamp_exceeds_ceiling_errors() {
        match clamp_estimate(1500.0, 1000).unwrap_err() {
            Error::LpBoundExceeded { estimate, ceiling } => {
                assert_eq!(estimate, 1500.0);
                assert_eq!(ceiling, 1000.0);
            }
            other => panic!("wrong error variant: {other:?}"),
        }
    }

    #[test]
    fn chain_bound_tighter_than_product() {
        let rows = [1_000u64, 1_000];
        let chain = ChainBound::new(vec![100, 100]).ceiling(&rows, &[(0, 1)]);
        assert_eq!(chain, 10_000);
        assert!(chain < ProductBound.ceiling(&rows, &[]));
    }

    #[test]
    fn chain_bound_three_table_chain() {
        let rows = [1_000u64, 2_000, 500];
        let chain = ChainBound::new(vec![100, 100, 100]);
        assert_eq!(chain.ceiling(&rows, &[(0, 1), (1, 2)]), 100_000);
    }

    #[test]
    fn chain_bound_no_predicates_falls_back() {
        let chain = ChainBound::new(vec![10, 20, 30]);
        assert_eq!(chain.ceiling(&[10, 20, 30], &[]), 10 * 20 * 30);
    }

    #[test]
    fn chain_bound_missing_distinct_count_defaults_to_one() {
        // No distinct counts at all: every divisor is 1, so the product is kept.
        let chain = ChainBound::new(vec![]);
        assert_eq!(chain.ceiling(&[100, 100], &[(0, 1)]), 10_000);
    }

    #[test]
    fn saturating_clamp_saturates() {
        assert_eq!(saturating_clamp(500.0, 1000), 500);
        assert_eq!(saturating_clamp(2000.0, 1000), 1000);
        assert_eq!(saturating_clamp(-5.0, 1000), 0);
        assert_eq!(saturating_clamp(f64::NAN, 1000), 0);
    }
}
#[cfg(all(test, feature = "lp_solver"))]
mod lp_tests {
    // Tests for the LP-based bound; compiled only with the `lp_solver` feature.
    use super::*;

    #[test]
    fn two_table_join_matches_principled_agm() {
        let rows = [1_000u64, 1_000_000u64];
        let bound = LpJoinBound::new().ceiling(&rows, &[(0, 1)]);
        assert!(
            (999..=1_001).contains(&bound),
            "expected ≈1000, got {bound}"
        );
        let coarse = AgmBound.ceiling(&rows, &[(0, 1)]);
        assert!(
            bound <= coarse,
            "LP bound {bound} must not exceed coarse AGM {coarse}"
        );
    }

    #[test]
    fn triangle_strictly_tighter_than_chain_and_product() {
        let rows = [1_000u64; 3];
        let triangle = [(0usize, 1usize), (1, 2), (0, 2)];
        let bound = LpJoinBound::new().ceiling(&rows, &triangle);
        assert!(
            (31_000u64..=32_000u64).contains(&bound),
            "expected ≈31_623, got {bound}"
        );
        let product = ProductBound.ceiling(&rows, &triangle);
        assert!(bound < product, "LP {bound} should be < product {product}");
        let chain = ChainBound::new(vec![10, 10, 10]).ceiling(&rows, &triangle);
        assert!(
            bound < chain,
            "LP {bound} should be < chain {chain} on the triangle"
        );
    }

    #[test]
    fn square_strictly_tighter_than_chain_and_product() {
        let rows = [100u64; 4];
        let four_cycle = [(0usize, 1usize), (1, 2), (2, 3), (3, 0)];
        let bound = LpJoinBound::new().ceiling(&rows, &four_cycle);
        assert!(
            (5_000..=15_000).contains(&bound),
            "expected ≈10_000, got {bound}"
        );
        let product = ProductBound.ceiling(&rows, &four_cycle);
        assert!(bound < product, "LP {bound} should be < product {product}");
        let chain = ChainBound::new(vec![4, 4, 4, 4]).ceiling(&rows, &four_cycle);
        assert!(
            bound < chain,
            "LP {bound} should be < chain {chain} on the 4-cycle"
        );
    }

    #[test]
    fn disconnected_components_multiply() {
        // Two independent two-table joins: the per-component bounds multiply.
        let rows = [100u64, 200, 50, 70];
        let bound = LpJoinBound::new().ceiling(&rows, &[(0usize, 1usize), (2, 3)]);
        assert!(
            (4_900..=5_100).contains(&bound),
            "expected ≈5000, got {bound}"
        );
    }

    #[test]
    fn singleton_component_contributes_row_count() {
        // Relation 2 is in no predicate, so it contributes its full row count.
        let rows = [100u64, 200, 99];
        let bound = LpJoinBound::new().ceiling(&rows, &[(0usize, 1usize)]);
        assert!(
            (9_800..=10_000).contains(&bound),
            "expected ≈9_900, got {bound}"
        );
    }

    #[test]
    fn lp_bound_dominates_product() {
        let rows = [37u64, 41, 43, 47, 53];
        let path = [(0usize, 1usize), (1, 2), (2, 3), (3, 4)];
        let bound = LpJoinBound::new().ceiling(&rows, &path);
        let product = ProductBound.ceiling(&rows, &path);
        assert!(
            bound <= product,
            "LP bound {bound} must be ≤ product {product}"
        );
    }

    #[test]
    fn empty_relations_zero() {
        assert_eq!(LpJoinBound::new().ceiling(&[], &[]), 0);
    }

    #[test]
    fn no_predicates_returns_product() {
        let rows = [10u64, 20, 30];
        assert_eq!(LpJoinBound::new().ceiling(&rows, &[]), 6_000);
    }

    #[test]
    fn ceiling_with_distinct_is_at_most_unconstrained() {
        let rows = [1_000u64, 1_000];
        let join = [(0usize, 1usize)];
        let a = LpJoinBound::with_distinct_counts(vec![10, 10]).ceiling_with_distinct(&rows, &join);
        let b = LpJoinBound::new().ceiling(&rows, &join);
        assert!(a <= b, "distinct-aware bound {a} must be tighter than {b}");
        assert!(a <= 11, "expected ≈10 with D=10, got {a}");
    }
}