use std::collections::{BTreeMap, BTreeSet};
use std::fmt::Debug;
use std::ops::{Add, AddAssign};
use std::rc::Rc;
use dyn_clone::DynClone;
use crate::operator::dynamic::balance::BalancerError;
/// Position of a variable in the solver's variable list (`MaxSat::add_variable` hands these out).
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
pub struct VariableIndex(pub usize);

/// Position of a constraint in the solver's constraint list (assigned by `MaxSat::add_constraint`).
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
pub struct ConstraintIndex(pub usize);
/// Cost of assigning a particular value to a variable.
///
/// The search prunes the highest-cost candidate values first, so larger costs are
/// the less desirable ones.
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
#[repr(transparent)]
pub struct Cost(pub u64);

impl Add for Cost {
    type Output = Self;

    /// Sums two costs using plain `u64` addition (overflow panics in debug builds and
    /// wraps in release builds, exactly like the underlying integer).
    fn add(self, rhs: Self) -> Self {
        let Cost(lhs) = self;
        Cost(lhs + rhs.0)
    }
}

impl AddAssign for Cost {
    /// In-place form of `Add`.
    fn add_assign(&mut self, rhs: Self) {
        *self = *self + rhs;
    }
}

/// Encoded candidate value of a variable; callers' value types are converted with `u8::from`.
pub type VariableValue = u8;
/// One level of the search: the decision that opened it, plus every assignment
/// removed while it was the innermost level, so both can be undone on backtrack.
#[derive(Clone, Debug)]
struct BacktrackFrame {
    /// The `(variable, value, cost)` taken out of the domain when this frame was pushed.
    decision: (VariableIndex, VariableValue, Cost),
    /// Assignments removed (via `MaxSat::remove_variable_assignment`) while this frame
    /// was on top of the stack.
    removed_assignments: Vec<(VariableIndex, VariableValue, Cost)>,
}

impl BacktrackFrame {
    /// Opens a frame for a fresh decision; nothing has been removed under it yet.
    fn new(variable: VariableIndex, value: VariableValue, cost: Cost) -> Self {
        let decision = (variable, value, cost);
        BacktrackFrame {
            decision,
            removed_assignments: Vec::default(),
        }
    }
}
/// Solver that picks one value per variable from its domain so that every hard
/// constraint holds, propagating constraints and backtracking on conflicts.
#[derive(Clone, Debug)]
pub struct MaxSat {
    /// Hard constraints, addressed by `ConstraintIndex`.
    constraints: Vec<Rc<dyn HardConstraint>>,
    /// Variables with their remaining domains, addressed by `VariableIndex`.
    variables: Vec<Variable>,
    /// Open search levels; the innermost decision is last.
    backtrack_stack: Vec<BacktrackFrame>,
}
/// A hard constraint over a fixed set of solver variables.
///
/// The `DynClone` supertrait together with `clone_trait_object!` below makes boxed
/// trait objects (`Box<dyn HardConstraint>`) cloneable.
pub trait HardConstraint: DynClone + Debug {
    /// Variables this constraint mentions. `MaxSat::add_constraint` links the constraint
    /// to each of them, and the solver re-runs `propagate` whenever one is affected.
    fn variables(&self) -> Vec<VariableIndex>;

    /// Prunes values from the domains in `maxsat`. The solver re-evaluates the
    /// constraints of every variable inserted into `affected_variables`.
    fn propagate(&self, maxsat: &mut MaxSat, affected_variables: &mut BTreeSet<VariableIndex>);
}

dyn_clone::clone_trait_object!(HardConstraint);
/// A solver variable together with the candidate values it may still take.
#[derive(Clone, Debug)]
pub struct Variable {
    /// Label given to `MaxSat::add_variable`.
    pub name: String,
    /// Values still allowed, each mapped to the cost of choosing it.
    pub domain: BTreeMap<VariableValue, Cost>,
    /// Constraints that mention this variable, in registration order.
    pub constraints: Vec<ConstraintIndex>,
}
impl Default for MaxSat {
fn default() -> Self {
Self::new()
}
}
impl MaxSat {
pub fn new() -> Self {
Self {
constraints: Vec::new(),
variables: Vec::new(),
backtrack_stack: Vec::new(),
}
}
pub fn variables(&self) -> &[Variable] {
&self.variables
}
pub fn validate_solution(
&self,
solution: &BTreeMap<VariableIndex, VariableValue>,
) -> Result<(), BalancerError> {
if solution.len() != self.variables.len() {
return Err(BalancerError::NoSolution);
}
for (index, variable) in self.variables.iter().enumerate() {
let var_index = VariableIndex(index);
if let Some(&value) = solution.get(&var_index) {
if !variable.domain.contains_key(&value) {
return Err(BalancerError::NoSolution);
}
} else {
return Err(BalancerError::NoSolution);
}
}
let mut temp_maxsat = self.clone();
temp_maxsat.backtrack_stack.clear();
for (var_index, &solution_value) in solution.iter() {
let variable = &mut temp_maxsat.variables[var_index.0];
let cost = variable
.domain
.get(&solution_value)
.copied()
.ok_or(BalancerError::NoSolution)?;
variable.domain.clear();
variable.domain.insert(solution_value, cost);
}
let mut affected_variables = BTreeSet::new();
for i in 0..temp_maxsat.variables.len() {
affected_variables.insert(VariableIndex(i));
}
while !affected_variables.is_empty() {
let mut new_affected_variables = BTreeSet::new();
let mut constraints_to_reevaluate = BTreeSet::new();
for variable in affected_variables.iter() {
for constraint in temp_maxsat.variables[variable.0].constraints.iter() {
constraints_to_reevaluate.insert(*constraint);
}
}
for constraint_index in constraints_to_reevaluate.iter() {
let constraint = temp_maxsat.constraints[constraint_index.0].clone();
constraint.propagate(&mut temp_maxsat, &mut new_affected_variables);
}
affected_variables = new_affected_variables;
}
if temp_maxsat
.variables
.iter()
.any(|variable| variable.domain.is_empty())
{
return Err(BalancerError::NoSolution);
}
Ok(())
}
pub fn add_variable<T>(&mut self, name: &str, domain: &BTreeMap<T, Cost>) -> VariableIndex
where
u8: From<T>,
T: TryFrom<u8, Error = ()> + Clone,
{
let index = self.variables.len();
self.variables.push(Variable {
name: name.to_string(),
domain: domain
.iter()
.map(|(value, cost)| (VariableValue::from(value.clone()), *cost))
.collect(),
constraints: Vec::new(),
});
VariableIndex(index)
}
pub fn add_constraint<C: HardConstraint + 'static>(&mut self, constraint: C) {
let index = ConstraintIndex(self.constraints.len());
for variable in constraint.variables() {
assert!(variable.0 < self.variables.len());
self.variables[variable.0].constraints.push(index);
}
self.constraints.push(Rc::new(constraint));
}
fn highest_cost_variable_assignment(&self) -> Option<(VariableIndex, VariableValue, Cost)> {
let mut highest_cost_variable: Option<(VariableIndex, Cost, VariableValue)> = None;
for (index, variable) in self.variables.iter().enumerate() {
if variable.domain.len() <= 1 {
continue;
}
for (value, cost) in variable.domain.iter() {
if highest_cost_variable.is_none() || *cost > highest_cost_variable.unwrap().1 {
highest_cost_variable = Some((VariableIndex(index), *cost, *value));
}
}
}
let (index, cost, value) = highest_cost_variable?;
Some((index, value, cost))
}
fn decision(&mut self, variable: VariableIndex, value: VariableValue, cost: Cost) {
self.variables[variable.0].domain.remove(&value).unwrap();
self.backtrack_stack
.push(BacktrackFrame::new(variable, value, cost));
}
pub fn remove_variable_assignment(
&mut self,
variable: VariableIndex,
value: VariableValue,
) -> bool {
if let Some(cost) = self.variables[variable.0].domain.remove(&value) {
if let Some(frame) = self.backtrack_stack.last_mut() {
frame.removed_assignments.push((variable, value, cost))
}
true
} else {
false
}
}
fn backtrack(
&mut self,
affected_variables: &mut BTreeSet<VariableIndex>,
) -> Result<(), BalancerError> {
if self.backtrack_stack.is_empty() {
return Err(BalancerError::NoSolution);
}
let BacktrackFrame {
decision: (decision_variable, decision_value, decision_cost),
removed_assignments,
} = self.backtrack_stack.pop().unwrap();
for (variable, value, cost) in removed_assignments.iter() {
self.variables[variable.0].domain.insert(*value, *cost);
}
let domain = self.variables[decision_variable.0].domain.clone();
for (val, _cost) in domain.into_iter() {
self.remove_variable_assignment(decision_variable, val);
}
self.variables[decision_variable.0]
.domain
.insert(decision_value, decision_cost);
affected_variables.insert(decision_variable);
Ok(())
}
pub fn solve(&mut self) -> Result<BTreeMap<VariableIndex, VariableValue>, BalancerError> {
assert!(
self.variables
.iter()
.all(|variable| !variable.domain.is_empty())
);
let mut affected_variables = BTreeSet::new();
for i in 0..self.variables.len() {
affected_variables.insert(VariableIndex(i));
}
loop {
while !affected_variables.is_empty() {
let mut new_affected_variables = BTreeSet::new();
let mut constraints_to_reevaluate = BTreeSet::new();
for variable in affected_variables.iter() {
for constraint in self.variables[variable.0].constraints.iter() {
constraints_to_reevaluate.insert(*constraint);
}
}
for constraint_index in constraints_to_reevaluate.iter() {
let constraint = self.constraints[constraint_index.0].clone();
constraint.propagate(self, &mut new_affected_variables);
}
affected_variables = new_affected_variables;
}
assert!(affected_variables.is_empty());
if self
.variables
.iter()
.any(|variable| variable.domain.is_empty())
{
self.backtrack(&mut affected_variables)?;
continue;
}
let Some((variable, value, cost)) = self.highest_cost_variable_assignment() else {
return Ok(self
.variables
.iter()
.enumerate()
.map(|(index, variable)| {
(
VariableIndex(index),
*variable.domain.iter().next().unwrap().0,
)
})
.collect());
};
self.decision(variable, value, cost);
affected_variables.insert(variable);
}
}
}
#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;

    use crate::operator::dynamic::balance::{
        JoinConstraint, MaxSat, PartitioningPolicy, maxsat::Cost,
    };

    /// Builds the three-policy domain used by every test variable.
    fn policy_domain(
        shard: u64,
        broadcast: u64,
        balance: u64,
    ) -> BTreeMap<PartitioningPolicy, Cost> {
        BTreeMap::from([
            (PartitioningPolicy::Shard, Cost(shard)),
            (PartitioningPolicy::Broadcast, Cost(broadcast)),
            (PartitioningPolicy::Balance, Cost(balance)),
        ])
    }

    // Already compiled only under `cfg(test)` through the parent module.
    mod proptests {
        use super::*;
        use proptest::prelude::*;

        /// Random domain with a strictly positive cost for each policy.
        fn partitioning_policy_domain() -> impl Strategy<Value = BTreeMap<PartitioningPolicy, Cost>>
        {
            (
                1u64..1_000_000u64,
                1u64..1_000_000u64,
                1u64..1_000_000u64,
            )
                .prop_map(|(shard, broadcast, balance)| policy_domain(shard, broadcast, balance))
        }

        proptest! {
            #![proptest_config(ProptestConfig::with_cases(1000))]
            #[test]
            fn proptest_maxsat(
                // Always at least two variables, so there is something to join.
                variable_domains in prop::collection::vec(partitioning_policy_domain(), 2..=10),
                constraint_pairs in prop::collection::vec(
                    (0usize..10, 0usize..10, prop::bool::ANY),
                    1..=20
                )
            ) {
                let mut maxsat = MaxSat::new();
                let variable_indices: Vec<_> = variable_domains
                    .iter()
                    .enumerate()
                    .map(|(i, domain)| maxsat.add_variable(&format!("v{}", i), domain))
                    .collect();

                for &(v1_idx, v2_idx, is_left_join) in &constraint_pairs {
                    // Skip pairs that reference a missing variable or join a variable with itself.
                    if let (Some(&v1), Some(&v2)) =
                        (variable_indices.get(v1_idx), variable_indices.get(v2_idx))
                    {
                        if v1_idx != v2_idx {
                            JoinConstraint::create_join_constraint(&mut maxsat, v1, v2, is_left_join);
                        }
                    }
                }

                match maxsat.solve() {
                    Ok(solution) => {
                        let validation = maxsat.validate_solution(&solution);
                        prop_assert!(
                            validation.is_ok(),
                            "Solution {:?} failed validation ({:?}) for maxsat with {} variables and {} constraints",
                            solution,
                            validation,
                            variable_indices.len(),
                            maxsat.constraints.len()
                        );
                    }
                    Err(error) => {
                        return Err(TestCaseError::fail(format!("solving failed: {:?}", error)));
                    }
                }
            }
        }
    }

    #[test]
    fn test_maxsat_join1() {
        let mut maxsat = MaxSat::new();
        let v1 = maxsat.add_variable("left-non-skewed", &policy_domain(0, 1000, 100));
        let v2 = maxsat.add_variable("right-non-skewed", &policy_domain(0, 1000, 100));
        JoinConstraint::create_join_constraint(&mut maxsat, v1, v2, false);
        let solution = maxsat.solve().unwrap();
        assert_eq!(
            solution,
            BTreeMap::from([
                (v1, PartitioningPolicy::Shard.into()),
                (v2, PartitioningPolicy::Shard.into())
            ])
        );
    }

    #[test]
    fn test_maxsat_join2() {
        let mut maxsat = MaxSat::new();
        let v1 = maxsat.add_variable(
            "left-large-skewed",
            &policy_domain(1_000_000, 1_000_000, 100_000),
        );
        let v2 = maxsat.add_variable("right-small-skewed", &policy_domain(1_000, 20_000, 100));
        JoinConstraint::create_join_constraint(&mut maxsat, v1, v2, false);
        let solution = maxsat.solve().unwrap();
        assert_eq!(
            solution,
            BTreeMap::from([
                (v1, PartitioningPolicy::Balance.into()),
                (v2, PartitioningPolicy::Broadcast.into())
            ])
        );
    }

    #[test]
    fn test_maxsat_join3() {
        let mut maxsat = MaxSat::new();
        let v1 = maxsat.add_variable(
            "left1-large-skewed",
            &policy_domain(1_000_000, 1_000_000, 100_000),
        );
        let v2 = maxsat.add_variable(
            "left2-large-skewed",
            &policy_domain(1_000_000, 1_000_000, 100_000),
        );
        let v3 = maxsat.add_variable("right-small-skewed", &policy_domain(1_000, 20_000, 100));
        JoinConstraint::create_join_constraint(&mut maxsat, v1, v3, false);
        JoinConstraint::create_join_constraint(&mut maxsat, v2, v3, false);
        let solution = maxsat.solve().unwrap();
        assert_eq!(
            solution,
            BTreeMap::from([
                (v1, PartitioningPolicy::Balance.into()),
                (v2, PartitioningPolicy::Balance.into()),
                (v3, PartitioningPolicy::Broadcast.into())
            ])
        );
    }

    #[test]
    fn test_maxsat_join4() {
        let mut maxsat = MaxSat::new();
        let v1 = maxsat.add_variable(
            "left1-large-skewed",
            &policy_domain(1_000_000, 1_000_000, 100_000),
        );
        let v2 = maxsat.add_variable(
            "left2-large-skewed",
            &policy_domain(10_000, 100_000, 20_000),
        );
        let v3 = maxsat.add_variable("right-small-skewed", &policy_domain(1_000, 20_000, 100));
        JoinConstraint::create_join_constraint(&mut maxsat, v1, v3, false);
        JoinConstraint::create_join_constraint(&mut maxsat, v2, v3, false);
        let solution = maxsat.solve().unwrap();
        assert_eq!(
            solution,
            BTreeMap::from([
                (v1, PartitioningPolicy::Balance.into()),
                (v2, PartitioningPolicy::Shard.into()),
                (v3, PartitioningPolicy::Broadcast.into())
            ])
        );
    }

    #[test]
    fn test_maxsat_join5() {
        let mut maxsat = MaxSat::new();
        let v1 = maxsat.add_variable("v1", &policy_domain(1_000_000, 1_000, 1));
        let v2 = maxsat.add_variable("v2", &policy_domain(1_000_000, 1_000, 1));
        let v3 = maxsat.add_variable("v3", &policy_domain(1_000_000, 1_000, 1));
        JoinConstraint::create_join_constraint(&mut maxsat, v1, v2, false);
        JoinConstraint::create_join_constraint(&mut maxsat, v1, v3, false);
        JoinConstraint::create_join_constraint(&mut maxsat, v2, v3, false);
        let solution = maxsat.solve().unwrap();
        assert_eq!(
            solution,
            BTreeMap::from([
                (v1, PartitioningPolicy::Broadcast.into()),
                (v2, PartitioningPolicy::Shard.into()),
                (v3, PartitioningPolicy::Shard.into())
            ])
        );
    }

    #[test]
    fn real_world1() {
        let mut maxsat = MaxSat::new();
        // Creation order fixes the variable indices, so keep it stable.
        let v1 = maxsat.add_variable("v1", &policy_domain(0, 0, 0));
        let v2 = maxsat.add_variable("v2", &policy_domain(0, 0, 0));
        let v3 = maxsat.add_variable("v3", &policy_domain(0, 0, 0));
        let v4 = maxsat.add_variable("v4", &policy_domain(100_000, 100_000, 13_000));
        let v5 = maxsat.add_variable("v5", &policy_domain(0, 0, 0));
        let x = maxsat.add_variable("x", &policy_domain(0, 0, 0));
        let y = maxsat.add_variable("y", &policy_domain(187_000, 187_000, 23_000));
        // Constraints are added in the order v1-x, v1-y, v2-x, ..., v5-y.
        for v in [v1, v2, v3, v4, v5] {
            JoinConstraint::create_join_constraint(&mut maxsat, v, x, false);
            JoinConstraint::create_join_constraint(&mut maxsat, v, y, false);
        }
        let solution = maxsat.solve().unwrap();
        assert_eq!(
            solution,
            BTreeMap::from([
                (v1, PartitioningPolicy::Broadcast.into()),
                (v2, PartitioningPolicy::Broadcast.into()),
                (v3, PartitioningPolicy::Broadcast.into()),
                (v4, PartitioningPolicy::Broadcast.into()),
                (v5, PartitioningPolicy::Broadcast.into()),
                (x, PartitioningPolicy::Balance.into()),
                (y, PartitioningPolicy::Balance.into()),
            ])
        );
    }
}