use crate::constraint::{AllDifferentExcept, ConstraintEnum};
use crate::domain::CostFiniteDomain;
use crate::{Csp, OptimizationMode, Pruning, SolveConfig, SolveStats};
/// Column value meaning "this row is left unmatched".
///
/// Used in row domains, as a pin target in `AssignmentBuilder::pin`, and in
/// `AssignmentSolution::assign`.
pub const SENTINEL: i32 = -1;

/// Search-node limit handed to the solver when no explicit budget was supplied
/// (including when `AssignmentBuilder::node_budget(None)` is called).
const DEFAULT_NODE_BUDGET: u64 = 1_000 * 1_000;
/// Fluent builder for a rectangular, optionally grouped assignment problem that is solved
/// as a constraint-satisfaction problem.
///
/// Start one with `assignment()` (or `AssignmentBuilder::default()`), configure it with the
/// chained setters, then call `solve`.
#[derive(Debug)]
#[derive(Default)]
pub struct AssignmentBuilder {
    /// Number of rows (one CSP variable each); `0` means "not set yet".
    n_rows: usize,
    /// Number of columns a row may be matched to; `0` means "not set yet".
    n_cols: usize,
    /// Row-major cost table: the cost of matching row `i` to column `k` is at `i * n_cols + k`.
    cost_matrix: Vec<f64>,
    /// Group id per row; empty means every row is in group `0`.
    row_groups: Vec<u8>,
    /// Group id per column; empty means every column is in group `0`.
    col_groups: Vec<u8>,
    /// `(row, col)` pin requests, validated at solve time; `col` may be `SENTINEL`.
    pins: Vec<(usize, i32)>,
    /// Cost charged for every row left unmatched.
    unmatch_penalty: f64,
    /// Solver node limit; `None` falls back to `DEFAULT_NODE_BUDGET` at solve time.
    node_budget: Option<u64>,
    /// Whether `.cost()` has populated `cost_matrix`.
    cost_set: bool,
}
/// Outcome of a successful `AssignmentBuilder::solve`.
#[derive(Debug)]
#[derive(Clone)]
pub struct AssignmentSolution {
    /// `assign[row]` is the column chosen for `row`, or `SENTINEL` when the row is unmatched.
    pub assign: Vec<i32>,
    /// Sum of the chosen cells' costs plus the unmatch penalty for every unmatched row.
    pub cost: f64,
    /// Solver statistics read back from the CSP after the search.
    pub stats: SolveStats,
}
/// Reasons `AssignmentBuilder::solve` can fail.
#[derive(Debug)]
pub enum AssignmentError {
    /// `.rows()` or `.cols()` was never set, or was set to zero.
    DimensionsNotSet,
    /// `.cost()` was never called.
    CostNotSet,
    /// A row or column group table does not have one entry per row / column.
    GroupLengthMismatch,
    /// A pin named a row out of range, a column that is neither `SENTINEL` nor in
    /// `0..n_cols`, or a column whose group differs from the row's group.
    InvalidPin {
        /// Row index passed to `.pin()`.
        row: usize,
        /// Column value passed to `.pin()`.
        col: i32,
    },
    /// Two pins on the same row disagree, or the solver produced no solution.
    Infeasible,
}

impl std::fmt::Display for AssignmentError {
    /// Writes a one-line, human-readable description prefixed with `AssignmentBuilder:`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match *self {
            // The only variant carrying data is formatted directly.
            Self::InvalidPin { row, col } => {
                return write!(
                    f,
                    "AssignmentBuilder: invalid pin (row={row}, col={col}); col must be SENTINEL or a valid 0..n_cols index sharing the row's group"
                );
            }
            Self::DimensionsNotSet => "AssignmentBuilder: .rows() and .cols() must both be set to a non-zero value before .solve()",
            Self::CostNotSet => "AssignmentBuilder: .cost() must be called before .solve()",
            Self::GroupLengthMismatch => "AssignmentBuilder: row_groups / col_groups length does not match the declared dimensions",
            Self::Infeasible => "AssignmentBuilder: CSP is infeasible under the supplied constraints",
        };
        f.write_str(message)
    }
}

impl std::error::Error for AssignmentError {}
pub fn assignment() -> AssignmentBuilder {
AssignmentBuilder::default()
}
impl AssignmentBuilder {
    /// Sets the number of rows (one CSP variable per row).
    ///
    /// Changing the row count after `.cost()` discards the previously computed cost matrix,
    /// because it was laid out for the old dimensions; call `.cost()` again afterwards or
    /// `solve` returns `AssignmentError::CostNotSet`.
    pub fn rows(mut self, n: usize) -> Self {
        if n != self.n_rows {
            self.invalidate_cost();
        }
        self.n_rows = n;
        self
    }

    /// Sets the number of columns a row may be matched to.
    ///
    /// Changing the column count after `.cost()` discards the previously computed cost
    /// matrix (its row stride would be wrong); call `.cost()` again afterwards or `solve`
    /// returns `AssignmentError::CostNotSet`.
    pub fn cols(mut self, n: usize) -> Self {
        if n != self.n_cols {
            self.invalidate_cost();
        }
        self.n_cols = n;
        self
    }

    /// Evaluates `f(row, col)` eagerly for every cell, in row-major order, and stores the
    /// resulting cost matrix.
    ///
    /// # Panics
    ///
    /// Panics if `.rows()` or `.cols()` has not been set to a non-zero value first.
    pub fn cost(mut self, f: impl Fn(usize, usize) -> f64) -> Self {
        assert!(
            self.n_rows > 0 && self.n_cols > 0,
            "AssignmentBuilder::cost() requires .rows() and .cols() to be set first"
        );
        let mut matrix = Vec::with_capacity(self.n_rows * self.n_cols);
        for i in 0..self.n_rows {
            for k in 0..self.n_cols {
                matrix.push(f(i, k));
            }
        }
        self.cost_matrix = matrix;
        self.cost_set = true;
        self
    }

    /// Assigns each row a group id via `f(row)`. A row can only be matched to columns of
    /// the same group. If never called, every row is in group `0`.
    ///
    /// # Panics
    ///
    /// Panics if `.rows()` has not been set to a non-zero value first. Without this check
    /// the table would be empty and `solve` would silently treat every row as group `0`,
    /// dropping the requested grouping.
    pub fn row_group(mut self, f: impl Fn(usize) -> u8) -> Self {
        assert!(
            self.n_rows > 0,
            "AssignmentBuilder::row_group() requires .rows() to be set first"
        );
        self.row_groups = (0..self.n_rows).map(f).collect();
        self
    }

    /// Assigns each column a group id via `f(col)`. A column can only be matched to rows
    /// of the same group. If never called, every column is in group `0`.
    ///
    /// # Panics
    ///
    /// Panics if `.cols()` has not been set to a non-zero value first. Without this check
    /// the table would be empty and `solve` would silently treat every column as group `0`.
    pub fn col_group(mut self, f: impl Fn(usize) -> u8) -> Self {
        assert!(
            self.n_cols > 0,
            "AssignmentBuilder::col_group() requires .cols() to be set first"
        );
        self.col_groups = (0..self.n_cols).map(f).collect();
        self
    }

    /// Forces `row` to take column `col`, or to stay unmatched when `col == SENTINEL`.
    ///
    /// Pins are only recorded here; `solve` validates them, returning
    /// `AssignmentError::InvalidPin` for out-of-range or cross-group pins and
    /// `AssignmentError::Infeasible` when one row is pinned to two different values.
    pub fn pin(mut self, row: usize, col: i32) -> Self {
        self.pins.push((row, col));
        self
    }

    /// Sets the cost charged for each row left unmatched (`SENTINEL`). Defaults to `0.0`.
    pub fn unmatch_penalty(mut self, penalty: f64) -> Self {
        self.unmatch_penalty = penalty;
        self
    }

    /// Sets the solver's search-node budget.
    ///
    /// `None` (also the default) falls back to `DEFAULT_NODE_BUDGET`; an unlimited search
    /// cannot be requested through this builder.
    pub fn node_budget(mut self, budget: Option<u64>) -> Self {
        self.node_budget = budget;
        self
    }

    /// Builds the CSP and solves it for a minimum-cost assignment.
    ///
    /// Each row becomes one variable. Its domain is `SENTINEL`, costing the unmatch penalty,
    /// plus every column in the row's group, costing `cost(row, col)`. A pinned row instead
    /// gets the single pinned value.
    ///
    /// There is one `AllDifferentExcept(SENTINEL)` constraint per row group. A row only sees
    /// columns of its own group, so this keeps two rows from taking the same column while
    /// letting any number of rows stay unmatched.
    ///
    /// # Errors
    ///
    /// - `DimensionsNotSet` if `.rows()` or `.cols()` is zero.
    /// - `CostNotSet` if `.cost()` was never called, or the dimensions changed after it.
    /// - `GroupLengthMismatch` if a group table does not have one entry per row/column.
    /// - `InvalidPin` if a pin names an invalid row, an invalid column, or a column of
    ///   another group.
    /// - `Infeasible` if two pins on one row disagree, or the solver returns no solution.
    ///   NOTE(review): this presumably also covers exhausting the node budget before any
    ///   solution is found; confirm against `Csp::solve_optimized`.
    pub fn solve(self) -> Result<AssignmentSolution, AssignmentError> {
        let n_rows = self.n_rows;
        let n_cols = self.n_cols;
        if n_rows == 0 || n_cols == 0 {
            return Err(AssignmentError::DimensionsNotSet);
        }
        // `rows()`/`cols()` already discard a matrix built for other dimensions. The length
        // check is a backstop so a stale matrix is never indexed with the wrong stride.
        if !self.cost_set || self.cost_matrix.len() != n_rows * n_cols {
            return Err(AssignmentError::CostNotSet);
        }
        let row_groups = Self::resolve_groups(self.row_groups, n_rows)?;
        let col_groups = Self::resolve_groups(self.col_groups, n_cols)?;
        let row_pin = Self::collect_pins(&self.pins, &row_groups, &col_groups)?;

        // One CSP variable per row, in row order.
        let mut csp: Csp<CostFiniteDomain> = Csp::new();
        let mut row_var_ids: Vec<u32> = Vec::with_capacity(n_rows);
        for (i, (&pin, &group)) in row_pin.iter().zip(&row_groups).enumerate() {
            let row_costs = &self.cost_matrix[i * n_cols..(i + 1) * n_cols];
            let domain =
                Self::row_domain(pin, row_costs, group, &col_groups, self.unmatch_penalty);
            row_var_ids.push(csp.add_variable(domain));
        }

        let mut unique_groups: Vec<u8> = row_groups.clone();
        unique_groups.sort_unstable();
        unique_groups.dedup();
        for group in unique_groups {
            let scope: Vec<u32> = (0..n_rows)
                .filter(|&i| row_groups[i] == group)
                .map(|i| row_var_ids[i])
                .collect();
            csp.add_constraint_enum(ConstraintEnum::AllDifferentExcept(
                AllDifferentExcept::new(scope, SENTINEL),
            ));
        }
        csp.finalize();

        let config = SolveConfig {
            optimization_mode: OptimizationMode::MinimizeCost,
            max_solutions: 1,
            pruning: Pruning::AcFc,
            node_budget: self.node_budget.or(Some(DEFAULT_NODE_BUDGET)),
            ..SolveConfig::default()
        };
        let solutions = csp.solve_optimized(&config);
        let stats = csp.stats().clone();
        let solution = solutions
            .into_iter()
            .next()
            .ok_or(AssignmentError::Infeasible)?;

        // The total is recomputed from the builder's own cost table rather than taken
        // from the solver.
        let mut assign: Vec<i32> = Vec::with_capacity(n_rows);
        let mut cost: f64 = 0.0;
        for (i, &var) in row_var_ids.iter().enumerate() {
            let v = solution[var as usize];
            cost += if v == SENTINEL {
                self.unmatch_penalty
            } else {
                self.cost_matrix[i * n_cols + v as usize]
            };
            assign.push(v);
        }
        Ok(AssignmentSolution {
            assign,
            cost,
            stats,
        })
    }

    /// Drops a cost matrix that no longer matches the declared dimensions.
    fn invalidate_cost(&mut self) {
        self.cost_matrix.clear();
        self.cost_set = false;
    }

    /// Returns the group table unchanged if it has exactly `n` entries, or an all-zero table
    /// if it was never set (empty). Any other length is a `GroupLengthMismatch`.
    fn resolve_groups(groups: Vec<u8>, n: usize) -> Result<Vec<u8>, AssignmentError> {
        if groups.is_empty() {
            Ok(vec![0; n])
        } else if groups.len() == n {
            Ok(groups)
        } else {
            Err(AssignmentError::GroupLengthMismatch)
        }
    }

    /// Validates every pin and folds the pins into one optional pinned value per row.
    /// `row_groups` / `col_groups` must already have one entry per row / column.
    fn collect_pins(
        pins: &[(usize, i32)],
        row_groups: &[u8],
        col_groups: &[u8],
    ) -> Result<Vec<Option<i32>>, AssignmentError> {
        let mut row_pin = vec![None; row_groups.len()];
        for &(row, col) in pins {
            if row >= row_groups.len() {
                return Err(AssignmentError::InvalidPin { row, col });
            }
            if col != SENTINEL {
                // Short-circuit keeps the group lookup from indexing out of range.
                let in_range = col >= 0 && (col as usize) < col_groups.len();
                if !in_range || col_groups[col as usize] != row_groups[row] {
                    return Err(AssignmentError::InvalidPin { row, col });
                }
            }
            match row_pin[row] {
                None => row_pin[row] = Some(col),
                Some(prev) if prev == col => {}
                // Two different pins on the same row can never both hold.
                Some(_) => return Err(AssignmentError::Infeasible),
            }
        }
        Ok(row_pin)
    }

    /// Builds one row's domain from its pin and its slice of the cost matrix.
    ///
    /// A pinned row gets only its pinned value. An unpinned row gets `SENTINEL` first,
    /// then every column of its group in ascending order, each paired with its cost.
    fn row_domain(
        pin: Option<i32>,
        row_costs: &[f64],
        row_group: u8,
        col_groups: &[u8],
        unmatch_penalty: f64,
    ) -> CostFiniteDomain {
        let mut values: Vec<i32> = Vec::with_capacity(col_groups.len() + 1);
        let mut costs: Vec<f64> = Vec::with_capacity(col_groups.len() + 1);
        match pin {
            Some(SENTINEL) => {
                values.push(SENTINEL);
                costs.push(unmatch_penalty);
            }
            Some(col) => {
                values.push(col);
                costs.push(row_costs[col as usize]);
            }
            None => {
                values.push(SENTINEL);
                costs.push(unmatch_penalty);
                for (k, (&cg, &c)) in col_groups.iter().zip(row_costs).enumerate() {
                    if cg == row_group {
                        values.push(k as i32);
                        costs.push(c);
                    }
                }
            }
        }
        CostFiniteDomain::new(values, costs)
    }
}