1use crate::constraint::{ConstraintEnum, VarId};
43use crate::domain::Domain;
44use crate::ordering::{self, Ordering};
45use crate::solver::ac3;
46use crate::solver::backtrack::Solution;
47use crate::solver::propagate;
48use crate::solver::SearchContext;
49use crate::variable::Variable;
50use crate::Pruning;
51
52pub struct OptimizeConfig {
54 pub pruning: Pruning,
55 pub ordering: Ordering,
56 pub max_solutions: usize,
57 pub constraint_weights: Vec<f64>,
58 pub var_constraint_ids: Vec<Vec<usize>>,
59 pub maximize: bool,
61 pub node_budget: Option<u64>,
64}
65
66pub trait DomainCostEval<D: Domain> {
69 fn cost(&self, domain: &D, val: &D::Value) -> f64;
71 fn min_cost(&self, domain: &D) -> f64;
73 fn max_cost(&self, domain: &D) -> f64;
75}
76
77pub struct ZeroCost;
79
80impl<D: Domain> DomainCostEval<D> for ZeroCost {
81 #[inline]
82 fn cost(&self, _domain: &D, _val: &D::Value) -> f64 {
83 0.0
84 }
85 #[inline]
86 fn min_cost(&self, _domain: &D) -> f64 {
87 0.0
88 }
89 #[inline]
90 fn max_cost(&self, _domain: &D) -> f64 {
91 0.0
92 }
93}
94
95pub struct CostDomainEval;
97
98impl<D: crate::domain::CostDomain> DomainCostEval<D> for CostDomainEval {
99 #[inline]
100 fn cost(&self, domain: &D, val: &D::Value) -> f64 {
101 domain.cost(val)
102 }
103 #[inline]
104 fn min_cost(&self, domain: &D) -> f64 {
105 domain.min_cost()
106 }
107 #[inline]
108 fn max_cost(&self, domain: &D) -> f64 {
109 domain
110 .values()
111 .into_iter()
112 .map(|v| domain.cost(&v))
113 .fold(f64::NEG_INFINITY, f64::max)
114 }
115}
116
117struct ScoredSolution<D: Domain> {
119 solution: Solution<D>,
120 cost: f64,
121}
122
123struct BranchBoundState<D: Domain> {
127 scored: Vec<ScoredSolution<D>>,
128 best_cost: f64,
129}
130
131pub fn branch_and_bound<D: Domain>(
134 variables: &mut [Variable<D>],
135 constraints: &[ConstraintEnum<D>],
136 adjacency: &crate::adjacency::Adjacency,
137 config: &OptimizeConfig,
138 stats: &mut crate::SolveStats,
139 cost_eval: &dyn DomainCostEval<D>,
140) -> Vec<Solution<D>>
141where
142 D::Value: PartialEq,
143{
144 let num_vars = variables.len();
145 let mut assignment: Vec<Option<D::Value>> = vec![None; num_vars];
146 let mut stack: Vec<VarId> = (0..num_vars as u32).collect();
147 let mut ctx = SearchContext { variables, constraints, adjacency, stats };
148 let mut bb = BranchBoundState {
149 scored: Vec::new(),
150 best_cost: f64::INFINITY,
151 };
152
153 bb_recurse(
154 &mut ctx,
155 config,
156 cost_eval,
157 &mut assignment,
158 &mut stack,
159 &mut bb,
160 0,
161 );
162
163 if config.maximize {
165 bb.scored.sort_by(|a, b| b.cost.partial_cmp(&a.cost).unwrap_or(std::cmp::Ordering::Equal));
166 } else {
167 bb.scored.sort_by(|a, b| a.cost.partial_cmp(&b.cost).unwrap_or(std::cmp::Ordering::Equal));
168 }
169
170 bb.scored.truncate(config.max_solutions);
172
173 bb.scored.into_iter().map(|s| s.solution).collect()
174}
175
176fn assignment_cost<D: Domain>(
178 assignment: &[Option<D::Value>],
179 variables: &[Variable<D>],
180 constraints: &[ConstraintEnum<D>],
181 cost_eval: &dyn DomainCostEval<D>,
182) -> f64
183where
184 D::Value: PartialEq,
185{
186 let mut cost = 0.0;
187
188 for (i, val) in assignment.iter().enumerate() {
190 if let Some(v) = val {
191 cost += cost_eval.cost(&variables[i].domain, v);
192 }
193 }
194
195 for c in constraints {
197 cost += c.soft_penalty(assignment);
198 }
199
200 cost
201}
202
203fn optimistic_bound<D: Domain>(
211 assignment: &[Option<D::Value>],
212 variables: &[Variable<D>],
213 constraints: &[ConstraintEnum<D>],
214 cost_eval: &dyn DomainCostEval<D>,
215 maximize: bool,
216) -> f64
217where
218 D::Value: PartialEq,
219{
220 let mut bound = 0.0;
221
222 for (i, val) in assignment.iter().enumerate() {
223 match val {
224 Some(v) => bound += cost_eval.cost(&variables[i].domain, v),
225 None => {
226 if maximize {
227 bound += cost_eval.max_cost(&variables[i].domain);
228 } else {
229 bound += cost_eval.min_cost(&variables[i].domain);
230 }
231 }
232 }
233 }
234
235 for c in constraints {
238 let scope = c.scope();
239 if scope.iter().all(|&v| assignment[v as usize].is_some()) {
240 bound += c.soft_penalty(assignment);
241 }
242 }
243
244 bound
245}
246
247fn bb_recurse<D: Domain>(
248 ctx: &mut SearchContext<'_, D>,
249 config: &OptimizeConfig,
250 cost_eval: &dyn DomainCostEval<D>,
251 assignment: &mut Vec<Option<D::Value>>,
252 stack: &mut Vec<VarId>,
253 bb: &mut BranchBoundState<D>,
254 depth: usize,
255) -> bool
256where
257 D::Value: PartialEq,
258{
259 if stack.is_empty() {
261 let cost = assignment_cost(assignment, ctx.variables, ctx.constraints, cost_eval);
262 let effective_cost = if config.maximize { -cost } else { cost };
263
264 if effective_cost < bb.best_cost {
265 bb.best_cost = effective_cost;
266 }
267
268 let sol: Solution<D> = assignment
269 .iter()
270 .map(|v| v.as_ref().unwrap().clone())
271 .collect();
272 bb.scored.push(ScoredSolution { solution: sol, cost });
273
274 return false;
276 }
277
278 if let Some(budget) = config.node_budget
286 && ctx.stats.nodes_explored >= budget
287 {
288 ctx.stats.budget_exceeded = true;
289 return true;
290 }
291
292 ctx.stats.nodes_explored += 1;
293
294 let ob = optimistic_bound(
296 assignment, ctx.variables, ctx.constraints, cost_eval, config.maximize,
297 );
298 let effective_ob = if config.maximize { -ob } else { ob };
299 if effective_ob >= bb.best_cost {
300 return false;
301 }
302
303 let idx = ordering::select_variable(
304 stack,
305 ctx.variables,
306 config.ordering,
307 &config.constraint_weights,
308 &config.var_constraint_ids,
309 )
310 .unwrap();
311
312 let var = stack.swap_remove(idx);
313
314 let mut values: Vec<_> = ctx.variables[var as usize].domain.iter().collect();
316 {
317 let domain = &ctx.variables[var as usize].domain;
318 if config.maximize {
319 values.sort_by(|a, b| {
320 let ca = cost_eval.cost(domain, b);
321 let cb = cost_eval.cost(domain, a);
322 ca.partial_cmp(&cb).unwrap_or(std::cmp::Ordering::Equal)
323 });
324 } else {
325 values.sort_by(|a, b| {
326 let ca = cost_eval.cost(domain, a);
327 let cb = cost_eval.cost(domain, b);
328 ca.partial_cmp(&cb).unwrap_or(std::cmp::Ordering::Equal)
329 });
330 }
331 }
332
333 for val in values {
334 assignment[var as usize] = Some(val.clone());
335 ctx.variables[var as usize].restrict_to(&val, depth);
336
337 let mut valid = true;
338 for &ci in ctx.adjacency.constraints_for(var) {
339 let ci = ci as usize;
340 let scope = ctx.constraints[ci].scope();
341 if scope.iter().all(|&v| assignment[v as usize].is_some())
342 && !ctx.constraints[ci].check(assignment)
343 {
344 valid = false;
345 break;
346 }
347 }
348
349 if valid {
350 let dwo = match config.pruning {
351 Pruning::None => false,
352 Pruning::ForwardChecking => propagate::forward_check(
353 var,
354 ctx.variables,
355 ctx.constraints,
356 ctx.adjacency,
357 assignment.as_mut_slice(),
358 ctx.stats,
359 depth,
360 ),
361 Pruning::Ac3 => ac3::ac3_from_variable(
362 var, ctx.variables, ctx.constraints, ctx.adjacency, assignment, ctx.stats, depth,
363 ),
364 Pruning::AcFc => propagate::ac_fc(
365 var,
366 ctx.variables,
367 ctx.constraints,
368 ctx.adjacency,
369 assignment.as_mut_slice(),
370 ctx.stats,
371 depth,
372 ),
373 };
374
375 if !dwo
376 && bb_recurse(
377 ctx,
378 config,
379 cost_eval,
380 assignment,
381 stack,
382 bb,
383 depth + 1,
384 )
385 {
386 return true;
387 }
388 }
389
390 ctx.stats.backtracks += 1;
391 assignment[var as usize] = None;
392 for v in ctx.variables.iter_mut() {
393 v.restore(depth);
394 }
395 }
396
397 stack.push(var);
398 false
399}