Skip to main content

csp_solver/builder/
assignment.rs

1//! Bipartite assignment COP builder.
2//!
3//! Fluent API for the common pattern of "assign N source rows to M
4//! target columns with per-cell costs, role-based AllDifferent groups,
5//! and optional hard pin constraints." Internally constructs a
6//! [`Csp<CostFiniteDomain>`] with one variable per row, an
7//! [`AllDifferentExcept`] per row-group, and `-1` as the unmatched
8//! sentinel; the underlying branch-and-bound search is invoked through
9//! [`Csp::solve_optimized`] with [`OptimizationMode::MinimizeCost`] and
10//! [`Pruning::AcFc`].
11//!
12//! `AssignmentBuilder` is intended for `n ≤ ~100` rows / cols. The
13//! branch-and-bound search degrades super-linearly past that point;
14//! larger problems should prefer a specialized Hungarian algorithm
15//! and feed the resulting permutation back into a Csp only if
16//! additional constraints (groups, pins) make the closed-form
17//! solution infeasible.
18//!
19//! # Example
20//!
21//! ```
22//! use csp_solver::assignment;
23//!
24//! let sol = assignment()
25//!     .rows(3)
26//!     .cols(3)
27//!     .cost(|i, k| if i == k { 0.0 } else { 10.0 })
28//!     .unmatch_penalty(100.0)
29//!     .solve()
30//!     .expect("solvable");
31//!
32//! assert_eq!(sol.assign, vec![0, 1, 2]);
33//! assert_eq!(sol.cost, 0.0);
34//! ```
35
36use crate::constraint::{AllDifferentExcept, ConstraintEnum};
37use crate::domain::CostFiniteDomain;
38use crate::{Csp, OptimizationMode, Pruning, SolveConfig, SolveStats};
39
/// Marker stored in [`AssignmentSolution::assign`] for a row that was
/// left unmatched.
///
/// The value is negative, so it can never be mistaken for a valid
/// 0-indexed column. Unpinned rows (and rows pinned to `SENTINEL`)
/// carry it as an ordinary domain entry priced at the caller-supplied
/// [`AssignmentBuilder::unmatch_penalty`], so the branch-and-bound
/// search weighs "leave unmatched" purely by total cost. A row pinned
/// to a concrete column has only that column in its domain.
pub const SENTINEL: i32 = -1;

/// Node budget handed to the underlying branch-and-bound search when
/// the caller has not supplied one through
/// [`AssignmentBuilder::node_budget`].
const DEFAULT_NODE_BUDGET: u64 = 1_000_000;
55
/// Fluent builder for bipartite assignment COPs.
///
/// Construct via [`assignment()`] (preferred) or [`Default::default`].
/// All setters consume `self` and return `self`, allowing chained
/// configuration. The terminal [`AssignmentBuilder::solve`] call
/// validates the configuration, materializes the underlying
/// [`Csp<CostFiniteDomain>`], runs branch-and-bound, and returns an
/// [`AssignmentSolution`] (or an [`AssignmentError`] on
/// mis-configuration / infeasibility).
///
/// Because every method takes the builder by value, the type is
/// `Clone`: clone a configured builder to solve the same problem under
/// several variations (different pins, budgets, penalties) without
/// re-evaluating the cost closure.
#[must_use = "an AssignmentBuilder does nothing until .solve() is called"]
#[derive(Debug, Clone, Default)]
pub struct AssignmentBuilder {
    /// Number of source rows; `0` means "not set yet".
    n_rows: usize,
    /// Number of target columns; `0` means "not set yet".
    n_cols: usize,
    /// Row-major `n_rows × n_cols` matrix of per-cell costs. Populated
    /// eagerly by [`AssignmentBuilder::cost`] so the builder owns no
    /// closure state.
    cost_matrix: Vec<f64>,
    /// Per-row group tags. Empty until
    /// [`AssignmentBuilder::row_group`] is called; an empty vector is
    /// treated as "every row in group 0" at solve time.
    row_groups: Vec<u8>,
    /// Per-column group tags. Empty until
    /// [`AssignmentBuilder::col_group`] is called; an empty vector is
    /// treated as "every column in group 0" at solve time.
    col_groups: Vec<u8>,
    /// Hard `(row, col)` equality pins, in the order they were added.
    /// Range and group compatibility are checked at
    /// [`AssignmentBuilder::solve`] time.
    pins: Vec<(usize, i32)>,
    /// Per-row cost paid when the assigned column is [`SENTINEL`].
    unmatch_penalty: f64,
    /// Optional cap on branch-and-bound nodes; `None` means use the
    /// crate default of `1_000_000`. See
    /// [`crate::SolveConfig::node_budget`] for the contract.
    node_budget: Option<u64>,
    /// Tracks whether [`AssignmentBuilder::cost`] has been called so
    /// `.solve()` can return [`AssignmentError::CostNotSet`] without
    /// guessing from `cost_matrix.is_empty()`.
    cost_set: bool,
}
93
94/// Result of a successful [`AssignmentBuilder::solve`] call.
95#[derive(Debug, Clone)]
96pub struct AssignmentSolution {
97    /// Length `n_rows`. Each entry is the assigned column index in
98    /// `0..n_cols`, or [`SENTINEL`] (`-1`) if the row was left
99    /// unmatched.
100    pub assign: Vec<i32>,
101    /// Total cost of the assignment: the sum of `cost_matrix[i][k]`
102    /// for each matched row `i → k`, plus
103    /// [`AssignmentBuilder::unmatch_penalty`] for each unmatched row.
104    pub cost: f64,
105    /// Statistics from the underlying branch-and-bound run. Inspect
106    /// [`SolveStats::budget_exceeded`] to distinguish best-so-far
107    /// from optimal solutions.
108    pub stats: SolveStats,
109}
110
/// Errors from [`AssignmentBuilder::solve`].
///
/// The type is `Clone + PartialEq + Eq` so callers and tests can
/// compare results directly (`assert_eq!(err, AssignmentError::Infeasible)`)
/// instead of pattern-matching every time.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AssignmentError {
    /// `.rows()` or `.cols()` was not called before `.solve()` (or
    /// either was set to zero).
    DimensionsNotSet,
    /// `.cost()` was not called before `.solve()`.
    CostNotSet,
    /// A custom `row_group` / `col_group` slice did not match the
    /// declared dimensions.
    GroupLengthMismatch,
    /// A pin references an out-of-range row or a column that is
    /// neither [`SENTINEL`] nor a valid `0..n_cols` index, or whose
    /// row-group does not match its target column's group.
    InvalidPin {
        /// Row index supplied to [`AssignmentBuilder::pin`].
        row: usize,
        /// Column index (or [`SENTINEL`]) supplied to
        /// [`AssignmentBuilder::pin`].
        col: i32,
    },
    /// The CSP has no feasible solution under the supplied
    /// constraints. Note that with [`SENTINEL`] always available a
    /// pure assignment problem is always feasible; this variant
    /// surfaces when pins or group constraints are mutually
    /// incompatible.
    Infeasible,
}
139
140impl std::fmt::Display for AssignmentError {
141    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
142        match self {
143            Self::DimensionsNotSet => {
144                write!(f, "AssignmentBuilder: .rows() and .cols() must both be set to a non-zero value before .solve()")
145            }
146            Self::CostNotSet => {
147                write!(f, "AssignmentBuilder: .cost() must be called before .solve()")
148            }
149            Self::GroupLengthMismatch => {
150                write!(f, "AssignmentBuilder: row_groups / col_groups length does not match the declared dimensions")
151            }
152            Self::InvalidPin { row, col } => {
153                write!(
154                    f,
155                    "AssignmentBuilder: invalid pin (row={row}, col={col}); col must be SENTINEL or a valid 0..n_cols index sharing the row's group"
156                )
157            }
158            Self::Infeasible => {
159                write!(f, "AssignmentBuilder: CSP is infeasible under the supplied constraints")
160            }
161        }
162    }
163}
164
165impl std::error::Error for AssignmentError {}
166
167/// Top-level constructor for an empty [`AssignmentBuilder`].
168///
169/// Equivalent to [`AssignmentBuilder::default`] but reads more
170/// naturally at the call site:
171///
172/// ```
173/// use csp_solver::assignment;
174///
175/// let sol = assignment()
176///     .rows(2)
177///     .cols(2)
178///     .cost(|i, k| (i + k) as f64)
179///     .solve()
180///     .expect("trivially solvable");
181/// assert_eq!(sol.assign.len(), 2);
182/// ```
183pub fn assignment() -> AssignmentBuilder {
184    AssignmentBuilder::default()
185}
186
187impl AssignmentBuilder {
188    /// Set the number of source rows.
189    pub fn rows(mut self, n: usize) -> Self {
190        self.n_rows = n;
191        self
192    }
193
194    /// Set the number of target columns.
195    pub fn cols(mut self, n: usize) -> Self {
196        self.n_cols = n;
197        self
198    }
199
200    /// Eagerly populate the row-major cost matrix.
201    ///
202    /// Calls `f(i, k)` exactly once per `(row, col)` cell during this
203    /// method, stores the result in an internal `Vec<f64>`, and
204    /// returns `self`. No closure is retained, which keeps the
205    /// builder `Send + Sync` even when constructed from non-`'static`
206    /// captures.
207    ///
208    /// # Panics
209    ///
210    /// Panics if [`AssignmentBuilder::rows`] or
211    /// [`AssignmentBuilder::cols`] has not been called yet — both
212    /// dimensions are required to know how to walk `f`.
213    pub fn cost(mut self, f: impl Fn(usize, usize) -> f64) -> Self {
214        assert!(
215            self.n_rows > 0 && self.n_cols > 0,
216            "AssignmentBuilder::cost() requires .rows() and .cols() to be set first"
217        );
218        let mut matrix = Vec::with_capacity(self.n_rows * self.n_cols);
219        for i in 0..self.n_rows {
220            for k in 0..self.n_cols {
221                matrix.push(f(i, k));
222            }
223        }
224        self.cost_matrix = matrix;
225        self.cost_set = true;
226        self
227    }
228
229    /// Tag each row with a `u8` group identifier.
230    ///
231    /// Rows in different groups are placed in independent
232    /// [`AllDifferentExcept`] scopes, and a row may only be assigned
233    /// to a column whose group identifier matches. Omitting the call
234    /// (or supplying `|_| 0`) puts every row in a single group, which
235    /// is the standard bipartite-assignment shape.
236    pub fn row_group(mut self, f: impl Fn(usize) -> u8) -> Self {
237        self.row_groups = (0..self.n_rows).map(f).collect();
238        self
239    }
240
241    /// Tag each column with a `u8` group identifier.
242    ///
243    /// See [`AssignmentBuilder::row_group`] for the semantics.
244    pub fn col_group(mut self, f: impl Fn(usize) -> u8) -> Self {
245        self.col_groups = (0..self.n_cols).map(f).collect();
246        self
247    }
248
249    /// Hard-pin row `row` to column `col`.
250    ///
251    /// `col` may be [`SENTINEL`] to force the row unmatched. Multiple
252    /// pins are accumulated; conflicting pins on the same row are
253    /// detected at [`AssignmentBuilder::solve`] time as
254    /// [`AssignmentError::Infeasible`].
255    pub fn pin(mut self, row: usize, col: i32) -> Self {
256        self.pins.push((row, col));
257        self
258    }
259
260    /// Set the per-row cost paid when a row is assigned to
261    /// [`SENTINEL`] (left unmatched).
262    pub fn unmatch_penalty(mut self, penalty: f64) -> Self {
263        self.unmatch_penalty = penalty;
264        self
265    }
266
267    /// Override the underlying branch-and-bound node budget.
268    ///
269    /// Passing `None` here is *not* the same as never calling this
270    /// method: `None` requests an unbounded search, while the default
271    /// (no call) installs a `1_000_000` node guard so a pathological
272    /// problem cannot hang the caller. See
273    /// [`crate::SolveConfig::node_budget`].
274    pub fn node_budget(mut self, budget: Option<u64>) -> Self {
275        self.node_budget = budget;
276        self
277    }
278
279    /// Validate the configuration, build the underlying CSP, and run
280    /// branch-and-bound to find the minimum-cost assignment.
281    pub fn solve(self) -> Result<AssignmentSolution, AssignmentError> {
282        // 1. Dimensions + cost must be set.
283        if self.n_rows == 0 || self.n_cols == 0 {
284            return Err(AssignmentError::DimensionsNotSet);
285        }
286        if !self.cost_set {
287            return Err(AssignmentError::CostNotSet);
288        }
289
290        // 2. Default groups to all-zero if the caller did not supply
291        //    them; otherwise verify lengths match the declared
292        //    dimensions.
293        let row_groups: Vec<u8> = if self.row_groups.is_empty() {
294            vec![0; self.n_rows]
295        } else if self.row_groups.len() == self.n_rows {
296            self.row_groups
297        } else {
298            return Err(AssignmentError::GroupLengthMismatch);
299        };
300        let col_groups: Vec<u8> = if self.col_groups.is_empty() {
301            vec![0; self.n_cols]
302        } else if self.col_groups.len() == self.n_cols {
303            self.col_groups
304        } else {
305            return Err(AssignmentError::GroupLengthMismatch);
306        };
307
308        // 3. Pre-validate pins and collapse them into a per-row map.
309        //    Pins are baked directly into each row's CostFiniteDomain
310        //    at construction time so the variable's `original_domain`
311        //    already encodes the singleton; this matters because
312        //    `Csp::solve_optimized` calls `Variable::reset()` at
313        //    search start and would otherwise undo any post-hoc
314        //    domain mutation. Multiple pins on the same row are
315        //    accepted only if they agree.
316        let mut row_pin: Vec<Option<i32>> = vec![None; self.n_rows];
317        for &(row, col) in &self.pins {
318            if row >= self.n_rows {
319                return Err(AssignmentError::InvalidPin { row, col });
320            }
321            if col != SENTINEL && (col < 0 || col as usize >= self.n_cols) {
322                return Err(AssignmentError::InvalidPin { row, col });
323            }
324            // Verify pin is compatible with the row's group: SENTINEL
325            // is always allowed, otherwise the column's group must
326            // match the row's.
327            if col != SENTINEL && col_groups[col as usize] != row_groups[row] {
328                return Err(AssignmentError::InvalidPin { row, col });
329            }
330            match row_pin[row] {
331                None => row_pin[row] = Some(col),
332                Some(prev) if prev == col => {} // duplicate, fine
333                Some(_) => return Err(AssignmentError::Infeasible),
334            }
335        }
336
337        // 4. Build one CostFiniteDomain per row, restricted to columns
338        //    whose group matches the row's group (and to the pinned
339        //    singleton when a pin is present). SENTINEL is always
340        //    available at the unmatch penalty unless overridden by a
341        //    non-SENTINEL pin.
342        let mut csp: Csp<CostFiniteDomain> = Csp::new();
343        let mut row_var_ids: Vec<u32> = Vec::with_capacity(self.n_rows);
344
345        for i in 0..self.n_rows {
346            let row_group = row_groups[i];
347            let row_offset = i * self.n_cols;
348
349            let mut values: Vec<i32> = Vec::with_capacity(self.n_cols + 1);
350            let mut costs: Vec<f64> = Vec::with_capacity(self.n_cols + 1);
351
352            match row_pin[i] {
353                Some(SENTINEL) => {
354                    values.push(SENTINEL);
355                    costs.push(self.unmatch_penalty);
356                }
357                Some(col) => {
358                    // col is guaranteed in 0..n_cols and group-compatible
359                    // by the pin validation above.
360                    values.push(col);
361                    costs.push(self.cost_matrix[row_offset + col as usize]);
362                }
363                None => {
364                    // SENTINEL first; CostFiniteDomain canonicalises to
365                    // ascending value order internally so the order at
366                    // construction is irrelevant for correctness, but
367                    // starting from SENTINEL keeps the (values, costs)
368                    // slices easy to read in a debugger.
369                    values.push(SENTINEL);
370                    costs.push(self.unmatch_penalty);
371                    for (k, &cg) in col_groups.iter().enumerate() {
372                        if cg == row_group {
373                            values.push(k as i32);
374                            costs.push(self.cost_matrix[row_offset + k]);
375                        }
376                    }
377                }
378            }
379
380            let domain = CostFiniteDomain::new(values, costs);
381            row_var_ids.push(csp.add_variable(domain));
382        }
383
384        // 5. Add one AllDifferentExcept per distinct row group.
385        let mut unique_groups: Vec<u8> = row_groups.clone();
386        unique_groups.sort_unstable();
387        unique_groups.dedup();
388        for group in unique_groups {
389            let scope: Vec<u32> = (0..self.n_rows)
390                .filter(|&i| row_groups[i] == group)
391                .map(|i| row_var_ids[i])
392                .collect();
393            // A single-row group still benefits from the constraint
394            // for symmetry — it's a no-op at search time but keeps
395            // the adjacency structure uniform across groups.
396            csp.add_constraint_enum(ConstraintEnum::AllDifferentExcept(
397                AllDifferentExcept::new(scope, SENTINEL),
398            ));
399        }
400
401        // 6. Finalize and run branch-and-bound.
402        csp.finalize();
403
404        let config = SolveConfig {
405            optimization_mode: OptimizationMode::MinimizeCost,
406            max_solutions: 1,
407            pruning: Pruning::AcFc,
408            node_budget: self.node_budget.or(Some(DEFAULT_NODE_BUDGET)),
409            ..SolveConfig::default()
410        };
411
412        let solutions = csp.solve_optimized(&config);
413        let stats = csp.stats().clone();
414
415        let solution = match solutions.into_iter().next() {
416            Some(s) => s,
417            None => return Err(AssignmentError::Infeasible),
418        };
419
420        // 7. Project the Solution<CostFiniteDomain> back into the
421        //    row-indexed `assign` vector and recompute the total cost
422        //    from the cost matrix + unmatch penalty so callers see a
423        //    value that matches their inputs exactly (as opposed to
424        //    the search's running total, which can drift through
425        //    floating-point summation order).
426        let mut assign: Vec<i32> = vec![SENTINEL; self.n_rows];
427        let mut cost: f64 = 0.0;
428        for i in 0..self.n_rows {
429            let v = solution[row_var_ids[i] as usize];
430            assign[i] = v;
431            if v == SENTINEL {
432                cost += self.unmatch_penalty;
433            } else {
434                cost += self.cost_matrix[i * self.n_cols + v as usize];
435            }
436        }
437
438        Ok(AssignmentSolution {
439            assign,
440            cost,
441            stats,
442        })
443    }
444}