use crate::options::{MultiStartConfig, QpWarmStart, SolverOptions, StartStrategy};
use crate::problem::{SolveStatus, SolverResult};
use crate::qp::problem::QpProblem;
use std::sync::Arc;
use std::time::{Duration, Instant};
use rayon::prelude::*;
/// Closure that builds a rayon pool; it receives the requested worker-thread count.
type PoolBuilderFn =
    dyn Fn(usize) -> Result<rayon::ThreadPool, rayon::ThreadPoolBuildError> + Send + Sync;

/// Optional override for how the multistart driver builds its rayon pool.
///
/// `None` means the driver calls `rayon::ThreadPoolBuilder` directly.
type ThreadPoolFactory = Option<Box<PoolBuilderFn>>;
/// Multiplier `A` of the linear congruential generator `x' = (A * x + C) mod 2^32`.
const LCG_A: u64 = 1_664_525;
/// Increment `C` of the linear congruential generator.
const LCG_C: u64 = 1_013_904_223;
/// Keeps only the low 32 bits of the state, i.e. reduces it modulo `2^32`.
const LCG_M_MASK: u64 = u32::MAX as u64;
/// Half-width `R` of the multistart sampling window: infinite bounds are
/// replaced by `±R`, and finite bounds are clamped into `[-R, R]`.
pub(crate) const MULTISTART_UNBOUNDED_RANGE: f64 = 1e1;
/// Advances the 32-bit LCG stored in `state` and returns the new state.
///
/// Computes `(A * state + C)` with wrapping 64-bit arithmetic and then masks it,
/// so the result is the value modulo `2^32` and always fits in 32 bits.
fn lcg_next(state: &mut u64) -> u64 {
    let next = LCG_A.wrapping_mul(*state).wrapping_add(LCG_C) & LCG_M_MASK;
    *state = next;
    next
}
/// Draws one sample in `[0, 1)` by scaling the next 32-bit LCG output by `2^-32`.
fn lcg_uniform_01(state: &mut u64) -> f64 {
    // 2^32: one past the largest value `lcg_next` can produce.
    const SCALE: f64 = 4_294_967_296.0;
    let raw = lcg_next(state);
    raw as f64 / SCALE
}
/// Draws one uniformly random point from the clamped bounding box.
///
/// For each coordinate, the sampling window is `[lb, ub] ∩ [-R, R]` with
/// `R = MULTISTART_UNBOUNDED_RANGE`. An infinite bound is replaced by `±R`.
/// One LCG draw is consumed per coordinate whose window has positive width.
///
/// When the window is empty or a single point, no draw is consumed. This
/// happens for a fixed variable, or for a box lying entirely outside
/// `[-R, R]`. In that case the window's lower edge is clamped back into
/// `[lb, ub]`, so every returned coordinate satisfies its own bounds.
fn sample_random_box(state: &mut u64, bounds: &[(f64, f64)]) -> Vec<f64> {
    bounds
        .iter()
        .map(|&(lb, ub)| {
            let lo = if lb.is_finite() {
                lb.max(-MULTISTART_UNBOUNDED_RANGE)
            } else {
                -MULTISTART_UNBOUNDED_RANGE
            };
            let hi = if ub.is_finite() {
                ub.min(MULTISTART_UNBOUNDED_RANGE)
            } else {
                MULTISTART_UNBOUNDED_RANGE
            };
            if hi <= lo {
                // Without this clamp, a box entirely below -R (e.g. (-20, -15))
                // would yield -R, which violates `ub`. `max`/`min` are used
                // instead of `clamp` because `clamp` panics when lb > ub.
                return lo.max(lb).min(ub);
            }
            lo + lcg_uniform_01(state) * (hi - lo)
        })
        .collect()
}
/// Generates `n_starts` Latin-hypercube points over the clamped bounding box.
///
/// Each coordinate's window is `[lb, ub] ∩ [-R, R]` with
/// `R = MULTISTART_UNBOUNDED_RANGE`; infinite bounds become `±R`. The window
/// is divided into `n_starts` equal strata. For every dimension, an
/// independent random permutation assigns each point a distinct stratum, and
/// the point is placed uniformly inside that stratum.
///
/// When a window is empty or a single point, no draw is consumed. This covers
/// fixed variables and boxes lying entirely outside `[-R, R]`. The window's
/// lower edge is clamped back into `[lb, ub]`, so the coordinate always
/// satisfies its bounds.
///
/// Returns an empty vector when `bounds` is empty or `n_starts == 0`. The
/// output depends only on `seed`, `n_starts` and `bounds`.
fn latin_hypercube(seed: u64, n_starts: usize, bounds: &[(f64, f64)]) -> Vec<Vec<f64>> {
    let n = bounds.len();
    if n == 0 || n_starts == 0 {
        return Vec::new();
    }
    // Offset the seed by a fixed constant; a zero state is replaced by 1.
    let mut state = seed.wrapping_add(0xA5A5_5A5A_5A5A_A5A5);
    if state == 0 {
        state = 1;
    }
    // One Fisher–Yates-shuffled permutation of the strata per dimension.
    let perms: Vec<Vec<usize>> = (0..n)
        .map(|_| {
            let mut p: Vec<usize> = (0..n_starts).collect();
            for i in (1..n_starts).rev() {
                // Map the 32-bit draw onto 0..=i with a multiply-shift, so the
                // index comes from the LCG's high bits. The low bits of a
                // power-of-two-modulus LCG have short periods (bit 0 simply
                // alternates), so `draw % (i + 1)` would yield patterned
                // permutations. u128 keeps the product overflow-free.
                let draw = u128::from(lcg_next(&mut state));
                let j = ((draw * (i as u128 + 1)) >> 32) as usize;
                p.swap(i, j);
            }
            p
        })
        .collect();
    (0..n_starts)
        .map(|s| {
            bounds
                .iter()
                .zip(&perms)
                .map(|(&(lb, ub), perm)| {
                    let lo = if lb.is_finite() {
                        lb.max(-MULTISTART_UNBOUNDED_RANGE)
                    } else {
                        -MULTISTART_UNBOUNDED_RANGE
                    };
                    let hi = if ub.is_finite() {
                        ub.min(MULTISTART_UNBOUNDED_RANGE)
                    } else {
                        MULTISTART_UNBOUNDED_RANGE
                    };
                    if hi <= lo {
                        // Degenerate window: stay feasible without consuming a draw.
                        return lo.max(lb).min(ub);
                    }
                    let u = lcg_uniform_01(&mut state);
                    let frac = (perm[s] as f64 + u) / n_starts as f64;
                    lo + frac * (hi - lo)
                })
                .collect()
        })
        .collect()
}
fn status_rank(s: &SolveStatus) -> u8 {
use SolveStatus::*;
match s {
Optimal => 0,
NonconvexGlobal => 0,
LocallyOptimal => 1,
NonconvexLocal => 1,
SuboptimalSolution => 2,
MaxIterations => 3,
Timeout => 4,
NumericalError => 5,
NonConvex(_) => 6,
Unbounded => 7,
Infeasible => 8,
NotSupported(_) => 9,
}
}
/// Returns the preferable of two results.
///
/// The lower `status_rank` wins. On equal rank, the rules are:
/// - if both objectives are finite, the smaller one wins, and `a` keeps ties;
/// - if exactly one objective is finite, that result wins;
/// - if neither is finite, `a` is kept.
fn pick_better(a: SolverResult, b: SolverResult) -> SolverResult {
    use std::cmp::Ordering;

    let keep_a = match status_rank(&a.status).cmp(&status_rank(&b.status)) {
        Ordering::Less => true,
        Ordering::Greater => false,
        Ordering::Equal => {
            let a_finite = a.objective.is_finite();
            let b_finite = b.objective.is_finite();
            if a_finite && b_finite {
                a.objective <= b.objective
            } else {
                // Keep `a` unless only `b` has a finite objective.
                a_finite || !b_finite
            }
        }
    };
    if keep_a {
        a
    } else {
        b
    }
}
fn solve_one(
problem: &QpProblem,
base_opts: &SolverOptions,
warm: Option<QpWarmStart>,
) -> SolverResult {
let mut opts = base_opts.clone();
opts.warm_start_qp = warm;
opts.multistart = None;
opts.threads = 1;
crate::qp::solve_qp_with(problem, &opts)
}
/// Builds the warm-start points for starts `1..n_starts`.
///
/// Start 0 is the cold start, so at most `n_starts - 1` points are returned,
/// and none when `n_starts <= 1`. How the points are produced depends on the
/// strategy:
/// - `LatinHypercube` generates a full `n_starts`-point design and drops its
///   first point;
/// - `RandomBox` draws each point independently from one LCG stream.
///
/// The result is fully determined by `config.seed`; a seed of 0 is treated as 1.
fn build_random_starts(config: &MultiStartConfig, bounds: &[(f64, f64)]) -> Vec<Vec<f64>> {
    let extra = config.n_starts.saturating_sub(1);
    if extra == 0 {
        return Vec::new();
    }
    // The LCG keeps only the low 32 bits of its state, so seeds differing only
    // in their upper 32 bits would otherwise collide. XOR-folding the upper half
    // into the lower half leaves every seed < 2^32 unchanged, and maps only 0 to 0.
    let folded = config.seed ^ (config.seed >> 32);
    let seed = if folded == 0 { 1 } else { folded };
    match config.strategy {
        StartStrategy::RandomBox => {
            let mut state = seed;
            (0..extra)
                .map(|_| sample_random_box(&mut state, bounds))
                .collect()
        }
        StartStrategy::LatinHypercube => latin_hypercube(seed, config.n_starts, bounds)
            .into_iter()
            .skip(1)
            .collect(),
    }
}
/// Optional hooks for observing and steering `solve_qp_multistart_with_hooks`
/// (in this file they are supplied only by the tests).
pub(crate) struct MultiStartHooks {
/// Invoked right before a start's solve begins. It is not invoked for starts
/// skipped by the deadline shortcut, nor for the fallback cold solve that runs
/// when no start produced a result.
pub on_solve_enter: Arc<dyn Fn() + Send + Sync>,
/// Invoked right after a start's solve returns.
pub on_solve_exit: Arc<dyn Fn() + Send + Sync>,
/// When `true`, starts are launched even after the shared deadline has passed
/// (the per-start deadline check is bypassed).
pub disable_deadline_shortcut: bool,
/// Optional replacement for the default rayon pool construction; it is called
/// with the number of worker threads. `None` uses `rayon::ThreadPoolBuilder`.
pub thread_pool_factory: ThreadPoolFactory,
}
/// Solves `problem` from `config.n_starts` starting points and returns the best result.
///
/// `options` are validated first. If `options.validate()` fails, the function
/// returns [`SolverResult::numerical_error`] instead of panicking; examples are
/// a negative or infinite `timeout_secs`, NaN tolerances, or zero threads.
pub fn solve_qp_multistart(
    problem: &QpProblem,
    options: &SolverOptions,
    config: &MultiStartConfig,
) -> SolverResult {
    match options.validate() {
        Ok(_) => solve_qp_multistart_with_hooks(problem, options, config, None),
        Err(_) => SolverResult::numerical_error(),
    }
}
/// Converts a relative timeout in seconds into an absolute deadline without panicking.
///
/// Returns:
/// - `Some(now + secs)` for representable non-negative values;
/// - `None` (no deadline) when `secs` is positive but too large to represent,
///   including `+inf`;
/// - `Some(now)`, an already-expired deadline, for negative or NaN input.
///
/// The public entry point rejects negative and infinite timeouts via `validate()`.
/// This helper keeps unvalidated crate-internal callers from panicking inside
/// `Duration::from_secs_f64` or `Instant + Duration`.
fn deadline_after(secs: f64) -> Option<Instant> {
    let now = Instant::now();
    match Duration::try_from_secs_f64(secs) {
        Ok(timeout) => now.checked_add(timeout),
        Err(_) if secs > 0.0 => None,
        Err(_) => Some(now),
    }
}

/// Runs `config.n_starts` independent QP solves and returns the best result.
///
/// Start 0 is a cold start. The other starts are warm-started from the points
/// produced by `build_random_starts`, with zero duals and `mu = 1.0`. Results
/// are combined with `pick_better` in start order, so the choice depends only
/// on which starts ran, not on thread scheduling.
///
/// A relative `timeout_secs` is turned into one absolute `deadline` shared by
/// all starts. With the deadline shortcut enabled (the default; `hooks` can
/// disable it), starts that would begin after the deadline are skipped. In
/// serial execution, iteration stops at the first such start. Skipped starts
/// never take part in the selection. If no start produced a result, a single
/// cold solve is performed.
///
/// Up to `min(options.threads, n_starts)` starts run concurrently on a
/// dedicated rayon pool. If the pool cannot be built, execution falls back to
/// the serial path.
///
/// This function does not call `options.validate()`; `solve_qp_multistart` does.
pub(crate) fn solve_qp_multistart_with_hooks(
    problem: &QpProblem,
    options: &SolverOptions,
    config: &MultiStartConfig,
    hooks: Option<&MultiStartHooks>,
) -> SolverResult {
    if config.n_starts <= 1 {
        return solve_one(problem, options, None);
    }

    // All starts share one absolute deadline instead of each restarting the clock.
    let mut shared_opts = options.clone();
    if shared_opts.deadline.is_none() {
        shared_opts.deadline = shared_opts.timeout_secs.and_then(deadline_after);
    }
    shared_opts.timeout_secs = None;

    let parallel = options.threads.max(1).min(config.n_starts);
    let randoms = build_random_starts(config, &problem.bounds);
    let m_orig = problem.num_constraints;
    // Start 0 is the cold start; the rest are primal warm starts with zero duals.
    let warms: Vec<Option<QpWarmStart>> = std::iter::once(None)
        .chain(randoms.into_iter().map(|x| {
            Some(QpWarmStart {
                x,
                y: vec![0.0; m_orig],
                mu: 1.0,
            })
        }))
        .collect();

    let shortcut_enabled = hooks.is_none_or(|h| !h.disable_deadline_shortcut);
    let deadline = shared_opts.deadline;
    // Whether a new start may still be launched at this moment.
    let may_start = || !shortcut_enabled || deadline.is_none_or(|d| Instant::now() < d);
    // Runs one start. `None` means the start was skipped because the deadline
    // had passed. Skipped starts are dropped rather than reported as synthetic
    // `Timeout` results, which would otherwise outrank genuine lower-ranked
    // outcomes (e.g. `Infeasible`) and make parallel runs disagree with serial ones.
    let worker = |warm: Option<QpWarmStart>| -> Option<SolverResult> {
        if !may_start() {
            return None;
        }
        if let Some(h) = hooks {
            (h.on_solve_enter)();
        }
        let r = solve_one(problem, &shared_opts, warm);
        if let Some(h) = hooks {
            (h.on_solve_exit)();
        }
        Some(r)
    };
    // Serial execution, shared by `threads == 1` and the pool-failure fallback:
    // stops at the first start that would begin after the deadline.
    let run_serial = |starts: Vec<Option<QpWarmStart>>| -> Vec<SolverResult> {
        starts.into_iter().map_while(&worker).collect()
    };

    let results: Vec<SolverResult> = if parallel <= 1 {
        run_serial(warms)
    } else {
        let pool_result = hooks
            .and_then(|h| h.thread_pool_factory.as_ref().map(|f| f(parallel)))
            .unwrap_or_else(|| {
                rayon::ThreadPoolBuilder::new()
                    .num_threads(parallel)
                    .build()
            });
        match pool_result {
            // Collecting into an indexed Vec keeps results in start order, so
            // the reduction below does not depend on scheduling.
            Ok(pool) => pool
                .install(|| warms.into_par_iter().map(&worker).collect::<Vec<_>>())
                .into_iter()
                .flatten()
                .collect::<Vec<SolverResult>>(),
            Err(e) => {
                log::warn!(
                    "multistart: rayon ThreadPool build failed ({e}); \
                     falling back to serial execution"
                );
                run_serial(warms)
            }
        }
    };

    results
        .into_iter()
        .reduce(pick_better)
        .unwrap_or_else(|| solve_one(problem, &shared_opts, None))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sparse::CscMatrix;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Two-variable QP shared by most tests: negative-definite `Q = diag(-2, -2)`,
    /// `c = 0`, no linear constraints, and box bounds `[-3, 3]^2`.
    fn concave_box_qp() -> QpProblem {
        let q = CscMatrix::from_triplets(&[0, 1], &[0, 1], &[-2.0, -2.0], 2, 2).unwrap();
        let c = vec![0.0_f64; 2];
        let a = CscMatrix::from_triplets(&[], &[], &[], 0, 2).unwrap();
        let bounds = vec![(-3.0, 3.0); 2];
        QpProblem::new(q, c, a, vec![], bounds, vec![]).unwrap()
    }

    /// `RandomBox` multistart configuration with the given size and seed.
    fn random_box(n_starts: usize, seed: u64) -> MultiStartConfig {
        MultiStartConfig {
            n_starts,
            seed,
            strategy: StartStrategy::RandomBox,
        }
    }

    /// Hooks that record the peak number of concurrently running solves.
    ///
    /// Each solve sleeps for `hold` right after entering, so that solves on
    /// different threads overlap long enough to be observed. Returns the hooks
    /// together with the peak counter.
    fn peak_tracking_hooks(hold: Duration) -> (MultiStartHooks, Arc<AtomicUsize>) {
        let active = Arc::new(AtomicUsize::new(0));
        let peak = Arc::new(AtomicUsize::new(0));
        let a_enter = Arc::clone(&active);
        let p_enter = Arc::clone(&peak);
        let hooks = MultiStartHooks {
            on_solve_enter: Arc::new(move || {
                let now_active = a_enter.fetch_add(1, Ordering::SeqCst) + 1;
                p_enter.fetch_max(now_active, Ordering::SeqCst);
                std::thread::sleep(hold);
            }),
            on_solve_exit: Arc::new(move || {
                active.fetch_sub(1, Ordering::SeqCst);
            }),
            disable_deadline_shortcut: false,
            thread_pool_factory: None,
        };
        (hooks, peak)
    }

    /// Pool factory whose spawn handler always fails, forcing the serial fallback.
    fn failing_pool_factory(msg: &'static str) -> ThreadPoolFactory {
        Some(Box::new(move |_n: usize| {
            rayon::ThreadPoolBuilder::new()
                .num_threads(1)
                .spawn_handler(move |_| -> std::io::Result<()> {
                    Err(std::io::Error::other(msg))
                })
                .build()
        }))
    }

    #[test]
    fn lcg_deterministic_and_in_unit_interval() {
        let mut a = 42u64;
        let mut b = 42u64;
        for _ in 0..1000 {
            let va = lcg_uniform_01(&mut a);
            let vb = lcg_uniform_01(&mut b);
            assert_eq!(va, vb);
            assert!((0.0..1.0).contains(&va));
        }
    }

    #[test]
    fn sample_random_box_respects_finite_bounds() {
        let bounds = vec![(-1.0, 1.0), (0.0, 5.0), (-100.0, -10.0)];
        let mut state = 12345u64;
        for _ in 0..100 {
            let x = sample_random_box(&mut state, &bounds);
            assert_eq!(x.len(), 3);
            for (xi, &(lb, ub)) in x.iter().zip(bounds.iter()) {
                assert!(*xi >= lb && *xi <= ub);
            }
        }
    }

    #[test]
    fn sample_random_box_clamps_infinite_bounds() {
        let bounds = vec![(f64::NEG_INFINITY, f64::INFINITY), (0.0, f64::INFINITY)];
        let mut state = 7u64;
        for _ in 0..100 {
            let x = sample_random_box(&mut state, &bounds);
            assert!(x[0].abs() <= MULTISTART_UNBOUNDED_RANGE);
            assert!(x[1] >= 0.0 && x[1] <= MULTISTART_UNBOUNDED_RANGE);
        }
    }

    #[test]
    fn latin_hypercube_covers_each_stratum_once_per_dim() {
        let bounds = vec![(0.0, 10.0), (-5.0, 5.0)];
        let n_starts = 8;
        let pts = latin_hypercube(99, n_starts, &bounds);
        assert_eq!(pts.len(), n_starts);
        for dim in 0..2 {
            let (lo, hi) = bounds[dim];
            let width = (hi - lo) / n_starts as f64;
            let mut hit = vec![false; n_starts];
            for p in pts.iter() {
                let stratum = (((p[dim] - lo) / width) as usize).min(n_starts - 1);
                hit[stratum] = true;
            }
            assert!(hit.iter().all(|&b| b), "dim {dim}: {hit:?}");
        }
    }

    #[test]
    fn pick_better_prefers_lower_obj_when_status_ties() {
        let a = SolverResult {
            status: SolveStatus::LocallyOptimal,
            objective: -1.0,
            ..Default::default()
        };
        let b = SolverResult {
            status: SolveStatus::LocallyOptimal,
            objective: -5.0,
            ..Default::default()
        };
        let r = pick_better(a.clone(), b.clone());
        assert_eq!(r.objective, -5.0);
        let r = pick_better(b, a);
        assert_eq!(r.objective, -5.0);
    }

    #[test]
    fn pick_better_prefers_optimal_over_suboptimal_even_if_obj_worse() {
        let opt = SolverResult {
            status: SolveStatus::Optimal,
            objective: 100.0,
            ..Default::default()
        };
        let sub = SolverResult {
            status: SolveStatus::SuboptimalSolution,
            objective: -100.0,
            ..Default::default()
        };
        let r = pick_better(opt.clone(), sub);
        assert_eq!(r.status, SolveStatus::Optimal);
    }

    #[test]
    fn multistart_deterministic_across_threads_count() {
        let prob = concave_box_qp();
        let cfg = random_box(8, 0xABCD);
        let o1 = SolverOptions {
            timeout_secs: Some(20.0),
            threads: 1,
            ..Default::default()
        };
        let o4 = SolverOptions {
            threads: 4,
            ..o1.clone()
        };
        let r1 = solve_qp_multistart(&prob, &o1, &cfg);
        let r4 = solve_qp_multistart(&prob, &o4, &cfg);
        assert!(
            (r1.objective - r4.objective).abs() < 1e-9,
            "thread=1 vs 4 must match: r1={} r4={}",
            r1.objective,
            r4.objective
        );
    }

    #[test]
    fn threads_actually_parallel_and_within_limit() {
        // (threads, n_starts, minimum expected peak, maximum allowed peak)
        let cases = [
            (2_usize, 10_usize, 2_usize, 2_usize),
            (4, 10, 2, 4),
            (8, 16, 2, 8),
            // Fewer starts than threads: concurrency is capped by n_starts.
            (4, 2, 2, 2),
        ];
        let prob = concave_box_qp();
        for (threads, n_starts, lo, hi) in cases.iter().copied() {
            let (hooks, peak) = peak_tracking_hooks(Duration::from_millis(50));
            let cfg = random_box(n_starts, 1);
            let opts = SolverOptions {
                timeout_secs: Some(30.0),
                threads,
                ..Default::default()
            };
            solve_qp_multistart_with_hooks(&prob, &opts, &cfg, Some(&hooks));
            let observed = peak.load(Ordering::SeqCst);
            // The parenthesised notes in these messages read
            // "insufficient parallelism" and "upper limit exceeded".
            assert!(
                observed >= lo,
                "threads={threads} n_starts={n_starts}: peak={observed} expected >= {lo} (並列稼働不足)"
            );
            assert!(
                observed <= hi,
                "threads={threads} n_starts={n_starts}: peak={observed} exceeds upper {hi} (上限超過)"
            );
        }
    }

    #[test]
    fn threads_eq_1_is_serial() {
        let prob = concave_box_qp();
        let (hooks, peak) = peak_tracking_hooks(Duration::from_millis(10));
        let cfg = random_box(6, 1);
        let opts = SolverOptions {
            timeout_secs: Some(20.0),
            threads: 1,
            ..Default::default()
        };
        solve_qp_multistart_with_hooks(&prob, &opts, &cfg, Some(&hooks));
        assert_eq!(peak.load(Ordering::SeqCst), 1, "threads=1 must be serial");
    }

    #[test]
    fn deadline_shortcut_skips_post_deadline_workers() {
        let prob = concave_box_qp();
        let cfg = random_box(8, 1);
        let opts = SolverOptions {
            deadline: Some(Instant::now() + Duration::from_millis(10)),
            threads: 2,
            ..Default::default()
        };
        // Hooks counting entered solves; each solve holds for 80 ms, well past the deadline.
        let make_hooks = |disable: bool| -> (MultiStartHooks, Arc<AtomicUsize>) {
            let entered = Arc::new(AtomicUsize::new(0));
            let entered_clone = entered.clone();
            (
                MultiStartHooks {
                    on_solve_enter: Arc::new(move || {
                        entered_clone.fetch_add(1, Ordering::SeqCst);
                        std::thread::sleep(Duration::from_millis(80));
                    }),
                    on_solve_exit: Arc::new(|| {}),
                    disable_deadline_shortcut: disable,
                    thread_pool_factory: None,
                },
                entered,
            )
        };
        let (h_on, entered_on) = make_hooks(false);
        let t0_on = Instant::now();
        solve_qp_multistart_with_hooks(&prob, &opts, &cfg, Some(&h_on));
        let dur_on = t0_on.elapsed();
        let n_entered_on = entered_on.load(Ordering::SeqCst);
        // The second run gets a fresh 10 ms deadline measured from now.
        let opts_off = SolverOptions {
            deadline: Some(Instant::now() + Duration::from_millis(10)),
            ..opts
        };
        let (h_off, entered_off) = make_hooks(true);
        let t0_off = Instant::now();
        solve_qp_multistart_with_hooks(&prob, &opts_off, &cfg, Some(&h_off));
        let dur_off = t0_off.elapsed();
        let n_entered_off = entered_off.load(Ordering::SeqCst);
        assert!(
            n_entered_on <= 4,
            "shortcut ON: at most 4/{n_starts} workers should enter (deadline=10ms, sleep=80ms), got {n}",
            n_starts = cfg.n_starts,
            n = n_entered_on
        );
        assert_eq!(
            n_entered_off,
            cfg.n_starts,
            "shortcut OFF: all {n_starts} workers must enter, got {n}",
            n_starts = cfg.n_starts,
            n = n_entered_off
        );
        assert!(
            dur_off.as_millis() >= dur_on.as_millis() * 2,
            "shortcut effect not observable: ON={:?} OFF={:?}",
            dur_on,
            dur_off
        );
    }

    #[test]
    fn deadline_shortcut_inactive_when_deadline_not_passed() {
        let prob = concave_box_qp();
        let entered = Arc::new(AtomicUsize::new(0));
        let entered_c = entered.clone();
        let hooks = MultiStartHooks {
            on_solve_enter: Arc::new(move || {
                entered_c.fetch_add(1, Ordering::SeqCst);
            }),
            on_solve_exit: Arc::new(|| {}),
            disable_deadline_shortcut: false,
            thread_pool_factory: None,
        };
        let cfg = random_box(6, 1);
        let opts = SolverOptions {
            timeout_secs: Some(20.0),
            threads: 2,
            ..Default::default()
        };
        let r = solve_qp_multistart_with_hooks(&prob, &opts, &cfg, Some(&hooks));
        assert_eq!(
            entered.load(Ordering::SeqCst),
            6,
            "all 6 starts must run when deadline is not breached"
        );
        assert!(
            r.objective.is_finite(),
            "objective should be finite, got {}",
            r.objective
        );
    }

    #[test]
    fn threadpool_build_failure_falls_back_to_serial() {
        let prob = concave_box_qp();
        // (threads, n_starts)
        let cases: &[(usize, usize)] = &[(4, 4), (2, 8), (3, 5)];
        for &(threads, n_starts) in cases {
            let hooks = MultiStartHooks {
                on_solve_enter: Arc::new(|| {}),
                on_solve_exit: Arc::new(|| {}),
                disable_deadline_shortcut: false,
                thread_pool_factory: failing_pool_factory("injected ThreadPool build failure"),
            };
            let cfg = random_box(n_starts, 42);
            let opts = SolverOptions {
                threads,
                timeout_secs: Some(20.0),
                ..Default::default()
            };
            let result = solve_qp_multistart_with_hooks(&prob, &opts, &cfg, Some(&hooks));
            assert!(
                result.objective.is_finite(),
                "threads={threads} n_starts={n_starts}: fallback must return finite objective, got status={:?} obj={}",
                result.status,
                result.objective
            );
        }
    }

    #[test]
    fn threadpool_fallback_result_matches_serial() {
        let prob = concave_box_qp();
        let cfg = random_box(6, 0xBEEF);
        let opts_serial = SolverOptions {
            threads: 1,
            timeout_secs: Some(20.0),
            ..Default::default()
        };
        let baseline = solve_qp_multistart(&prob, &opts_serial, &cfg);
        let hooks = MultiStartHooks {
            on_solve_enter: Arc::new(|| {}),
            on_solve_exit: Arc::new(|| {}),
            disable_deadline_shortcut: false,
            thread_pool_factory: failing_pool_factory("injected"),
        };
        let opts_fallback = SolverOptions {
            threads: 4,
            timeout_secs: Some(20.0),
            ..Default::default()
        };
        let fallback_result =
            solve_qp_multistart_with_hooks(&prob, &opts_fallback, &cfg, Some(&hooks));
        assert!(
            (baseline.objective - fallback_result.objective).abs() < 1e-9,
            "fallback objective {fallback} must match serial baseline {base}",
            fallback = fallback_result.objective,
            base = baseline.objective,
        );
    }

    #[test]
    fn invalid_options_rejected_at_multistart_entry() {
        let q = CscMatrix::from_triplets(&[0], &[0], &[-2.0], 1, 1).unwrap();
        let a = CscMatrix::from_triplets(&[], &[], &[], 0, 1).unwrap();
        let prob = QpProblem::new(q, vec![0.0], a, vec![], vec![(-2.0, 2.0)], vec![]).unwrap();
        let cfg = random_box(3, 1);
        let cases: &[(&str, SolverOptions)] = &[
            (
                "neg timeout_secs",
                SolverOptions {
                    timeout_secs: Some(-1.0),
                    ..Default::default()
                },
            ),
            (
                "inf timeout_secs",
                SolverOptions {
                    timeout_secs: Some(f64::INFINITY),
                    ..Default::default()
                },
            ),
            (
                "nan primal_tol",
                SolverOptions {
                    primal_tol: f64::NAN,
                    ..Default::default()
                },
            ),
            (
                "zero threads",
                SolverOptions {
                    threads: 0,
                    ..Default::default()
                },
            ),
        ];
        for (label, opts) in cases {
            let result = solve_qp_multistart(&prob, opts, &cfg);
            assert_eq!(
                result.status,
                SolveStatus::NumericalError,
                "solve_qp_multistart with {label} must return NumericalError (not panic)"
            );
        }
    }
}