Skip to main content

csp_solver/solver/
optimize.rs

//! Branch-and-bound optimization search.
//!
//! Extends backtracking with cost tracking: at each search node, computes an
//! optimistic bound on the total cost of any completion (a lower bound when
//! minimizing, an upper bound when maximizing). Prunes the node when that
//! bound cannot strictly improve on the incumbent cost.
//!
//! # Deferred extensions (Phase 2 design decisions)
//!
//! **`TieredCostEval` / lazy cost evaluation** — deferred because the
//! [`CostFiniteDomain::min_cost`] incremental `Cell` cache (landed in
//! Phase 2A) covers the immediate bottleneck: `optimistic_bound` was
//! calling `min_cost()` at O(domain) cost per node, and the cache amortizes
//! that to O(1). Note that this only helps minimization: when maximizing,
//! `CostDomainEval::max_cost` still scans the whole domain on every call.
//! A tiered evaluator (cheap proxy lower bound + expensive actual cost with
//! memoization) only justifies its extra API surface once a second consumer
//! with a genuinely non-trivial cost function appears. The existing
//! [`DomainCostEval`] trait is the extension point — implement a new
//! evaluator type; no solver-core changes are needed.
//!
//! **`solve_with_warm_start`** — deferred because no consumer currently
//! demands incremental re-solving (accepting a previous solution as the
//! initial incumbent for tighter B&B pruning from node zero). The
//! implementation is ~20 LOC: add a `warm_start: Option<&Solution<D>>` field
//! to [`OptimizeConfig`] (which then needs a lifetime and a domain type
//! parameter), and seed the `BranchBoundState` — its `best_cost` and
//! `scored` list — before the first `bb_recurse` call. Land it when the
//! first incremental consumer (e.g., a playground animation scrubber)
//! appears.
//!
//! **Unified `Constraint` trait** — the current `ConstraintEnum` dispatch
//! and separate `SoftConstraint` trait are adequate for the constraint
//! vocabulary in use. Collapsing them into a single trait with a default
//! `penalty() -> f64::INFINITY` method would improve readability, not
//! performance (enum dispatch is already static, with no vtable call).
//! Land it when the constraint vocabulary grows enough to justify the
//! refactor.
//!
//! **`tracing` instrumentation** — adding `#[instrument]` spans on
//! `branch_and_bound`, `propagate`, and `ac3_from_variable` would enable
//! production diagnostics (measuring propagation time, identifying
//! bottleneck constraints). Deferred as polish — the `SolveStats` struct
//! already tracks `nodes_explored`, `backtracks`, `propagations`, and
//! `budget_exceeded`, which covers coarse-grained profiling needs.

42use crate::constraint::{ConstraintEnum, VarId};
43use crate::domain::Domain;
44use crate::ordering::{self, Ordering};
45use crate::solver::ac3;
46use crate::solver::backtrack::Solution;
47use crate::solver::propagate;
48use crate::solver::SearchContext;
49use crate::variable::Variable;
50use crate::Pruning;
51
52/// Configuration for branch-and-bound optimization.
53pub struct OptimizeConfig {
54    pub pruning: Pruning,
55    pub ordering: Ordering,
56    pub max_solutions: usize,
57    pub constraint_weights: Vec<f64>,
58    pub var_constraint_ids: Vec<Vec<usize>>,
59    /// If true, maximize cost instead of minimize.
60    pub maximize: bool,
61    /// Maximum number of search nodes before aborting early.
62    /// See [`crate::SolveConfig::node_budget`].
63    pub node_budget: Option<u64>,
64}
65
66/// Cost evaluator for domains. Passed into the optimizer so that the same
67/// search code works for both `CostDomain` and plain `Domain` (zero cost).
68pub trait DomainCostEval<D: Domain> {
69    /// Cost of assigning `val` to the variable whose current domain is `domain`.
70    fn cost(&self, domain: &D, val: &D::Value) -> f64;
71    /// Lower bound on the minimum cost achievable from `domain`.
72    fn min_cost(&self, domain: &D) -> f64;
73    /// Upper bound on the maximum cost achievable from `domain`.
74    fn max_cost(&self, domain: &D) -> f64;
75}
76
77/// No-op evaluator: all costs are zero. Used when D doesn't implement CostDomain.
78pub struct ZeroCost;
79
80impl<D: Domain> DomainCostEval<D> for ZeroCost {
81    #[inline]
82    fn cost(&self, _domain: &D, _val: &D::Value) -> f64 {
83        0.0
84    }
85    #[inline]
86    fn min_cost(&self, _domain: &D) -> f64 {
87        0.0
88    }
89    #[inline]
90    fn max_cost(&self, _domain: &D) -> f64 {
91        0.0
92    }
93}
94
95/// Evaluator that delegates to CostDomain methods.
96pub struct CostDomainEval;
97
98impl<D: crate::domain::CostDomain> DomainCostEval<D> for CostDomainEval {
99    #[inline]
100    fn cost(&self, domain: &D, val: &D::Value) -> f64 {
101        domain.cost(val)
102    }
103    #[inline]
104    fn min_cost(&self, domain: &D) -> f64 {
105        domain.min_cost()
106    }
107    #[inline]
108    fn max_cost(&self, domain: &D) -> f64 {
109        domain
110            .values()
111            .into_iter()
112            .map(|v| domain.cost(&v))
113            .fold(f64::NEG_INFINITY, f64::max)
114    }
115}
116
117/// A scored solution: the assignment together with its total cost.
118struct ScoredSolution<D: Domain> {
119    solution: Solution<D>,
120    cost: f64,
121}
122
123/// Mutable accumulator for branch-and-bound: scored solutions found so far
124/// and the best (effective) cost seen, threaded through the recursive search
125/// to avoid per-function argument bloat.
126struct BranchBoundState<D: Domain> {
127    scored: Vec<ScoredSolution<D>>,
128    best_cost: f64,
129}
130
131/// Run branch-and-bound search. Returns up to `max_solutions` solutions,
132/// sorted by cost (best first according to the optimization direction).
133pub fn branch_and_bound<D: Domain>(
134    variables: &mut [Variable<D>],
135    constraints: &[ConstraintEnum<D>],
136    adjacency: &crate::adjacency::Adjacency,
137    config: &OptimizeConfig,
138    stats: &mut crate::SolveStats,
139    cost_eval: &dyn DomainCostEval<D>,
140) -> Vec<Solution<D>>
141where
142    D::Value: PartialEq,
143{
144    let num_vars = variables.len();
145    let mut assignment: Vec<Option<D::Value>> = vec![None; num_vars];
146    let mut stack: Vec<VarId> = (0..num_vars as u32).collect();
147    let mut ctx = SearchContext { variables, constraints, adjacency, stats };
148    let mut bb = BranchBoundState {
149        scored: Vec::new(),
150        best_cost: f64::INFINITY,
151    };
152
153    bb_recurse(
154        &mut ctx,
155        config,
156        cost_eval,
157        &mut assignment,
158        &mut stack,
159        &mut bb,
160        0,
161    );
162
163    // Sort by cost: best first (lowest for minimize, highest for maximize).
164    if config.maximize {
165        bb.scored.sort_by(|a, b| b.cost.partial_cmp(&a.cost).unwrap_or(std::cmp::Ordering::Equal));
166    } else {
167        bb.scored.sort_by(|a, b| a.cost.partial_cmp(&b.cost).unwrap_or(std::cmp::Ordering::Equal));
168    }
169
170    // Keep only the best `max_solutions`.
171    bb.scored.truncate(config.max_solutions);
172
173    bb.scored.into_iter().map(|s| s.solution).collect()
174}
175
176/// Compute the cost of a complete assignment.
177fn assignment_cost<D: Domain>(
178    assignment: &[Option<D::Value>],
179    variables: &[Variable<D>],
180    constraints: &[ConstraintEnum<D>],
181    cost_eval: &dyn DomainCostEval<D>,
182) -> f64
183where
184    D::Value: PartialEq,
185{
186    let mut cost = 0.0;
187
188    // Domain costs.
189    for (i, val) in assignment.iter().enumerate() {
190        if let Some(v) = val {
191            cost += cost_eval.cost(&variables[i].domain, v);
192        }
193    }
194
195    // Soft constraint penalties.
196    for c in constraints {
197        cost += c.soft_penalty(assignment);
198    }
199
200    cost
201}
202
203/// Compute the optimistic bound on the cost of any completion.
204///
205/// For minimize: returns a lower bound (assigned vars use actual cost,
206/// unassigned use min_cost).
207/// For maximize: returns an upper bound (assigned vars use actual cost,
208/// unassigned use max_cost). This is then negated by the caller to
209/// compare against the negated incumbent.
210fn optimistic_bound<D: Domain>(
211    assignment: &[Option<D::Value>],
212    variables: &[Variable<D>],
213    constraints: &[ConstraintEnum<D>],
214    cost_eval: &dyn DomainCostEval<D>,
215    maximize: bool,
216) -> f64
217where
218    D::Value: PartialEq,
219{
220    let mut bound = 0.0;
221
222    for (i, val) in assignment.iter().enumerate() {
223        match val {
224            Some(v) => bound += cost_eval.cost(&variables[i].domain, v),
225            None => {
226                if maximize {
227                    bound += cost_eval.max_cost(&variables[i].domain);
228                } else {
229                    bound += cost_eval.min_cost(&variables[i].domain);
230                }
231            }
232        }
233    }
234
235    // Soft constraint penalties for fully-assigned scopes.
236    // (Partially-assigned scopes contribute 0 optimistically.)
237    for c in constraints {
238        let scope = c.scope();
239        if scope.iter().all(|&v| assignment[v as usize].is_some()) {
240            bound += c.soft_penalty(assignment);
241        }
242    }
243
244    bound
245}
246
247fn bb_recurse<D: Domain>(
248    ctx: &mut SearchContext<'_, D>,
249    config: &OptimizeConfig,
250    cost_eval: &dyn DomainCostEval<D>,
251    assignment: &mut Vec<Option<D::Value>>,
252    stack: &mut Vec<VarId>,
253    bb: &mut BranchBoundState<D>,
254    depth: usize,
255) -> bool
256where
257    D::Value: PartialEq,
258{
259    // Complete assignment — record solution.
260    if stack.is_empty() {
261        let cost = assignment_cost(assignment, ctx.variables, ctx.constraints, cost_eval);
262        let effective_cost = if config.maximize { -cost } else { cost };
263
264        if effective_cost < bb.best_cost {
265            bb.best_cost = effective_cost;
266        }
267
268        let sol: Solution<D> = assignment
269            .iter()
270            .map(|v| v.as_ref().unwrap().clone())
271            .collect();
272        bb.scored.push(ScoredSolution { solution: sol, cost });
273
274        // For optimization, keep searching for better solutions.
275        return false;
276    }
277
278    // Budget guard: abort early if the search has exceeded its node
279    // budget. Return `true` so the recursion unwinds cleanly; whatever
280    // scored solutions have been found so far remain in `scored` and
281    // are surfaced at the end of `branch_and_bound`. `budget_exceeded`
282    // is set so callers can distinguish best-so-far from optimal.
283    // Checked before `nodes_explored += 1` so the post-budget node is
284    // never counted and the flag is set exactly once per search.
285    if let Some(budget) = config.node_budget
286        && ctx.stats.nodes_explored >= budget
287    {
288        ctx.stats.budget_exceeded = true;
289        return true;
290    }
291
292    ctx.stats.nodes_explored += 1;
293
294    // Bound check: prune if the optimistic bound can't beat the incumbent.
295    let ob = optimistic_bound(
296        assignment, ctx.variables, ctx.constraints, cost_eval, config.maximize,
297    );
298    let effective_ob = if config.maximize { -ob } else { ob };
299    if effective_ob >= bb.best_cost {
300        return false;
301    }
302
303    let idx = ordering::select_variable(
304        stack,
305        ctx.variables,
306        config.ordering,
307        &config.constraint_weights,
308        &config.var_constraint_ids,
309    )
310    .unwrap();
311
312    let var = stack.swap_remove(idx);
313
314    // Value ordering: sort by cost (lowest first for minimize, highest for maximize).
315    let mut values: Vec<_> = ctx.variables[var as usize].domain.iter().collect();
316    {
317        let domain = &ctx.variables[var as usize].domain;
318        if config.maximize {
319            values.sort_by(|a, b| {
320                let ca = cost_eval.cost(domain, b);
321                let cb = cost_eval.cost(domain, a);
322                ca.partial_cmp(&cb).unwrap_or(std::cmp::Ordering::Equal)
323            });
324        } else {
325            values.sort_by(|a, b| {
326                let ca = cost_eval.cost(domain, a);
327                let cb = cost_eval.cost(domain, b);
328                ca.partial_cmp(&cb).unwrap_or(std::cmp::Ordering::Equal)
329            });
330        }
331    }
332
333    for val in values {
334        assignment[var as usize] = Some(val.clone());
335        ctx.variables[var as usize].restrict_to(&val, depth);
336
337        let mut valid = true;
338        for &ci in ctx.adjacency.constraints_for(var) {
339            let ci = ci as usize;
340            let scope = ctx.constraints[ci].scope();
341            if scope.iter().all(|&v| assignment[v as usize].is_some())
342                && !ctx.constraints[ci].check(assignment)
343            {
344                valid = false;
345                break;
346            }
347        }
348
349        if valid {
350            let dwo = match config.pruning {
351                Pruning::None => false,
352                Pruning::ForwardChecking => propagate::forward_check(
353                    var,
354                    ctx.variables,
355                    ctx.constraints,
356                    ctx.adjacency,
357                    assignment.as_mut_slice(),
358                    ctx.stats,
359                    depth,
360                ),
361                Pruning::Ac3 => ac3::ac3_from_variable(
362                    var, ctx.variables, ctx.constraints, ctx.adjacency, assignment, ctx.stats, depth,
363                ),
364                Pruning::AcFc => propagate::ac_fc(
365                    var,
366                    ctx.variables,
367                    ctx.constraints,
368                    ctx.adjacency,
369                    assignment.as_mut_slice(),
370                    ctx.stats,
371                    depth,
372                ),
373            };
374
375            if !dwo
376                && bb_recurse(
377                    ctx,
378                    config,
379                    cost_eval,
380                    assignment,
381                    stack,
382                    bb,
383                    depth + 1,
384                )
385            {
386                return true;
387            }
388        }
389
390        ctx.stats.backtracks += 1;
391        assignment[var as usize] = None;
392        for v in ctx.variables.iter_mut() {
393            v.restore(depth);
394        }
395    }
396
397    stack.push(var);
398    false
399}