use std::sync::Mutex;
use std::time::{Duration, Instant};
/// Identifies the MUS search function a recorded search is attributed to.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum MusFunction {
    Size1,
    Slice,
    Cake,
    Quick,
}

impl MusFunction {
    /// Human-readable name used as the section heading in the MUS report.
    fn label(self) -> &'static str {
        match self {
            Self::Size1 => "size_1 (tiny scan)",
            Self::Slice => "slice",
            Self::Cake => "cake",
            Self::Quick => "quick",
        }
    }

    /// Slot of this function in per-function arrays (0-based, declaration order).
    fn index(self) -> usize {
        // Fieldless enum with no explicit discriminants: the variants are
        // numbered 0, 1, 2, 3 in declaration order.
        self as usize
    }
}
/// Result of a single MUS search.
#[derive(Debug, Clone)]
pub enum MusOutcome {
    /// A MUS was found; the payload is recorded as its size.
    Found(usize),
    /// The search finished without producing a MUS.
    NotFound,
    /// The search was cut off before finishing.
    Timeout,
}
/// One MUS search: how long it took, what it produced, and which function ran it.
#[derive(Debug, Clone)]
pub struct MusRecord {
    /// Wall-clock time spent in the search.
    pub duration: Duration,
    /// What the search produced.
    pub outcome: MusOutcome,
    /// The search function that was invoked.
    pub function: MusFunction,
}
/// Invocation count and accumulated wall-clock time of one phase.
#[derive(Debug, Clone, Default)]
pub struct PhaseStats {
    /// Number of recorded runs of the phase.
    pub calls: u64,
    /// Sum of all recorded durations.
    pub total_time: Duration,
}

impl PhaseStats {
    /// Accounts for one more run of the phase that took `d`.
    fn record(&mut self, d: Duration) {
        let Self { calls, total_time } = self;
        *calls += 1;
        *total_time += d;
    }
}
/// Wall-clock histogram buckets: exclusive upper bound in whole milliseconds,
/// paired with the label printed for that row. The final `u128::MAX` entry
/// catches everything longer than the previous bound.
const BUCKETS: &[(u128, &str)] = &[
    (1, "< 1ms"),
    (5, "1–5ms"),
    (20, "5–20ms"),
    (100, "20–100ms"),
    (500, "100–500ms"),
    (2_000, "500ms–2s"),
    (u128::MAX, "> 2s"),
];

/// Number of wall-clock histogram buckets.
const NUM_TIME_BUCKETS: usize = BUCKETS.len();

/// Returns the row of `BUCKETS` that `d` falls into.
///
/// That is the first bucket whose bound exceeds `d`'s whole-millisecond
/// count, or the last bucket when no bound does.
fn bucket_index(d: Duration) -> usize {
    let elapsed_ms = d.as_millis();
    // Every bucket whose bound is <= elapsed_ms lies strictly before ours.
    let passed = BUCKETS
        .iter()
        .take_while(|&&(limit, _)| elapsed_ms >= limit)
        .count();
    passed.min(NUM_TIME_BUCKETS - 1)
}
/// Conflict-count histogram buckets: exclusive upper bound paired with the
/// label printed for that row. The final `usize::MAX` entry is the catch-all.
const CONFLICT_BUCKETS: &[(usize, &str)] = &[
    (1, "0 conflicts"),
    (10, "1–9"),
    (100, "10–99"),
    (1_000, "100–999"),
    (10_000, "1k–9k"),
    (usize::MAX, "≥ 10k"),
];

/// Number of conflict-count histogram buckets.
const NUM_CONF_BUCKETS: usize = CONFLICT_BUCKETS.len();

/// Returns the row of `CONFLICT_BUCKETS` that a call with `n` conflicts
/// falls into (the last row if `n` reaches every bound).
fn conflict_bucket_index(n: usize) -> usize {
    for (slot, &(limit, _)) in CONFLICT_BUCKETS.iter().enumerate() {
        if n < limit {
            return slot;
        }
    }
    NUM_CONF_BUCKETS - 1
}
/// Local copy of a SAT solver call's result, used for statistics bookkeeping.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SatResult {
    /// The formula was satisfiable.
    Sat,
    /// The formula was unsatisfiable.
    Unsat,
    /// The solver call was interrupted.
    Interrupted,
}
impl From<rustsat::solvers::SolverResult> for SatResult {
    /// Maps each `rustsat` solver result onto the variant of the same name.
    fn from(r: rustsat::solvers::SolverResult) -> Self {
        use rustsat::solvers::SolverResult as Raw;
        match r {
            Raw::Sat => Self::Sat,
            Raw::Unsat => Self::Unsat,
            Raw::Interrupted => Self::Interrupted,
        }
    }
}
/// One SAT solver call: its duration, conflict count and result.
#[derive(Debug, Clone)]
pub struct SatCallRecord {
    /// Wall-clock time spent in the call.
    pub duration: Duration,
    /// Number of conflicts reported for the call.
    pub conflicts: usize,
    /// How the call ended.
    pub result: SatResult,
}
/// Running totals and histograms over every recorded SAT call.
///
/// Histogram rows follow `BUCKETS` (time) and `CONFLICT_BUCKETS` (conflicts);
/// the three columns of each row count sat, unsat and interrupted results.
#[derive(Debug, Clone, Copy)]
struct SatStats {
    /// Number of recorded calls.
    total: usize,
    /// Calls that returned sat.
    n_sat: usize,
    /// Calls that returned unsat.
    n_unsat: usize,
    /// Calls that were interrupted.
    n_interrupted: usize,
    /// Sum of all call durations.
    total_time: Duration,
    /// Sum of all conflict counts.
    total_conflicts: usize,
    /// Per time bucket: `[sat, unsat, interrupted]` call counts.
    time_by_result: [[u32; 3]; NUM_TIME_BUCKETS],
    /// Per time bucket: summed duration of the calls in it.
    time_bucket_dur: [Duration; NUM_TIME_BUCKETS],
    /// Per conflict bucket: `[sat, unsat, interrupted]` call counts.
    conf_by_result: [[u32; 3]; NUM_CONF_BUCKETS],
    /// Per conflict bucket: summed conflicts of the calls in it.
    conf_bucket_total: [usize; NUM_CONF_BUCKETS],
}

/// Empty `SatStats` (nothing recorded), usable in `const`/`static` initialisers.
const SAT_STATS_ZERO: SatStats = SatStats {
    total: 0,
    n_sat: 0,
    n_unsat: 0,
    n_interrupted: 0,
    total_time: Duration::ZERO,
    total_conflicts: 0,
    time_by_result: [[0; 3]; NUM_TIME_BUCKETS],
    time_bucket_dur: [Duration::ZERO; NUM_TIME_BUCKETS],
    conf_by_result: [[0; 3]; NUM_CONF_BUCKETS],
    conf_bucket_total: [0; NUM_CONF_BUCKETS],
};
impl SatStats {
    /// Adds one SAT call to the totals, its result counter and both histograms.
    fn record(&mut self, duration: Duration, conflicts: usize, result: SatResult) {
        self.total += 1;
        self.total_time += duration;
        self.total_conflicts += conflicts;

        // Select the per-result counter together with the histogram column
        // (0 = sat, 1 = unsat, 2 = interrupted).
        let (result_counter, column) = match result {
            SatResult::Sat => (&mut self.n_sat, 0),
            SatResult::Unsat => (&mut self.n_unsat, 1),
            SatResult::Interrupted => (&mut self.n_interrupted, 2),
        };
        *result_counter += 1;

        let time_row = bucket_index(duration);
        self.time_by_result[time_row][column] += 1;
        self.time_bucket_dur[time_row] += duration;

        let conflict_row = conflict_bucket_index(conflicts);
        self.conf_by_result[conflict_row][column] += 1;
        self.conf_bucket_total[conflict_row] += conflicts;
    }

    /// Writes a human-readable report to stderr.
    ///
    /// When nothing has been recorded, only the header and call count are printed.
    /// Histogram rows with no calls are omitted.
    fn print_summary(&self) {
        eprintln!("=== SAT Call Statistics ===");
        eprintln!("Total SAT calls: {}", self.total);
        if self.total == 0 {
            eprintln!("===========================");
            return;
        }
        eprintln!(
            "sat {} unsat {} interrupted {} \
             total time {:.3}s total conflicts {}",
            self.n_sat,
            self.n_unsat,
            self.n_interrupted,
            self.total_time.as_secs_f64(),
            self.total_conflicts,
        );

        eprintln!();
        eprintln!(" Time distribution:");
        eprintln!(
            " {:>12} {:>7} {:>7} {:>11} {:>7} {:>9}",
            "bucket", "sat", "unsat", "interrupted", "total", "time (s)"
        );
        eprintln!(" {}", "-".repeat(62));
        let time_rows = BUCKETS
            .iter()
            .zip(&self.time_by_result)
            .zip(&self.time_bucket_dur);
        for ((&(_, label), &[sat, unsat, int]), spent) in time_rows {
            let row = sat + unsat + int;
            if row > 0 {
                eprintln!(
                    " {:>12} {:>7} {:>7} {:>11} {:>7} {:>9.3}",
                    label,
                    sat,
                    unsat,
                    int,
                    row,
                    spent.as_secs_f64(),
                );
            }
        }

        eprintln!();
        eprintln!(" Conflict distribution:");
        eprintln!(
            " {:>14} {:>7} {:>7} {:>11} {:>7} {:>12}",
            "bucket", "sat", "unsat", "interrupted", "total", "conflicts"
        );
        eprintln!(" {}", "-".repeat(66));
        let conflict_rows = CONFLICT_BUCKETS
            .iter()
            .zip(&self.conf_by_result)
            .zip(&self.conf_bucket_total);
        for ((&(_, label), &[sat, unsat, int]), summed) in conflict_rows {
            let row = sat + unsat + int;
            if row > 0 {
                eprintln!(
                    " {:>14} {:>7} {:>7} {:>11} {:>7} {:>12}",
                    label, sat, unsat, int, row, summed,
                );
            }
        }
        eprintln!();
        eprintln!("===========================");
    }
}
static SAT_STATS: Mutex<SatStats> = Mutex::new(SAT_STATS_ZERO);
pub fn record_sat_call(
duration: Duration,
conflicts: usize,
result: rustsat::solvers::SolverResult,
) {
if let Ok(mut s) = SAT_STATS.lock() {
s.record(duration, conflicts, result.into());
}
}
pub fn reset_sat_stats() {
if let Ok(mut s) = SAT_STATS.lock() {
*s = SAT_STATS_ZERO;
}
}
pub fn print_sat_stats() {
if let Ok(s) = SAT_STATS.lock() {
s.print_summary();
}
}
/// Number of `MusFunction` variants; sizes the per-function statistics array.
const NUM_MUS_FUNCTIONS: usize = 4;
/// Counters for the searches performed by one MUS function.
///
/// `time_by_outcome` rows follow `BUCKETS`; the three columns count found,
/// not-found and timeout outcomes. The `size_*` fields only track `Found` results.
#[derive(Debug, Clone, Copy)]
struct MusFuncStats {
    /// Number of recorded searches.
    total: usize,
    /// Searches that found a MUS.
    n_found: usize,
    /// Searches that finished without a MUS.
    n_notfound: usize,
    /// Searches that timed out.
    n_timeout: usize,
    /// Sum of all search durations.
    total_time: Duration,
    /// Smallest MUS size seen; starts at `usize::MAX` until the first `Found`.
    size_min: usize,
    /// Largest MUS size seen.
    size_max: usize,
    /// Sum of all found MUS sizes (for the average).
    size_sum: usize,
    /// Per time bucket: `[found, not-found, timeout]` counts.
    time_by_outcome: [[u32; 3]; NUM_TIME_BUCKETS],
    /// Per time bucket: summed duration of the searches in it.
    time_bucket_dur: [Duration; NUM_TIME_BUCKETS],
}

/// Empty `MusFuncStats` (nothing recorded), usable in `const` initialisers.
const MUS_FUNC_ZERO: MusFuncStats = MusFuncStats {
    total: 0,
    n_found: 0,
    n_notfound: 0,
    n_timeout: 0,
    total_time: Duration::ZERO,
    size_min: usize::MAX,
    size_max: 0,
    size_sum: 0,
    time_by_outcome: [[0; 3]; NUM_TIME_BUCKETS],
    time_bucket_dur: [Duration::ZERO; NUM_TIME_BUCKETS],
};
impl MusFuncStats {
    /// Adds one search that took `duration` and ended with `outcome`.
    ///
    /// A `Found(size)` outcome also updates the MUS-size min, max and sum.
    fn record(&mut self, duration: Duration, outcome: &MusOutcome) {
        self.total += 1;
        self.total_time += duration;
        // Histogram column: 0 = found, 1 = not-found, 2 = timeout.
        let column = match *outcome {
            MusOutcome::Found(size) => {
                self.n_found += 1;
                self.size_min = self.size_min.min(size);
                self.size_max = self.size_max.max(size);
                self.size_sum += size;
                0
            }
            MusOutcome::NotFound => {
                self.n_notfound += 1;
                1
            }
            MusOutcome::Timeout => {
                self.n_timeout += 1;
                2
            }
        };
        let row = bucket_index(duration);
        self.time_by_outcome[row][column] += 1;
        self.time_bucket_dur[row] += duration;
    }

    /// Writes this function's section of the MUS report to stderr under
    /// heading `label`.
    ///
    /// Prints nothing when no searches were recorded, and omits empty
    /// time-bucket rows.
    fn print_summary(&self, label: &str) {
        if self.total == 0 {
            return;
        }
        eprintln!();
        eprintln!(
            "── {} ─── {} calls, {:.3}s total (found {}, not-found {}, timeout {})",
            label,
            self.total,
            self.total_time.as_secs_f64(),
            self.n_found,
            self.n_notfound,
            self.n_timeout,
        );
        if self.n_found > 0 {
            let mean_size = self.size_sum as f64 / self.n_found as f64;
            eprintln!(
                " MUS sizes: min={} max={} avg={:.1}",
                self.size_min, self.size_max, mean_size
            );
        }
        eprintln!(
            " {:>12} {:>7} {:>10} {:>8} {:>7} {:>9}",
            "bucket", "found", "not-found", "timeout", "total", "time (s)"
        );
        eprintln!(" {}", "-".repeat(62));
        let rows = BUCKETS
            .iter()
            .zip(&self.time_by_outcome)
            .zip(&self.time_bucket_dur);
        for ((&(_, bucket_label), &[found, notfound, timeout]), spent) in rows {
            let row_total = found + notfound + timeout;
            if row_total > 0 {
                eprintln!(
                    " {:>12} {:>7} {:>10} {:>8} {:>7} {:>9.3}",
                    bucket_label,
                    found,
                    notfound,
                    timeout,
                    row_total,
                    spent.as_secs_f64(),
                );
            }
        }
    }
}
/// Aggregated MUS-search statistics.
///
/// Holds one counter set per MUS function (indexed by `MusFunction::index`)
/// plus call counts and wall-clock totals for the batch-MUS and solve-step phases.
#[derive(Debug, Clone)]
pub struct MusStats {
    /// Per-function counters, indexed by `MusFunction::index()`.
    per_func: [MusFuncStats; NUM_MUS_FUNCTIONS],
    /// Calls and accumulated time of the batch-MUS phase.
    pub phase_batch_mus: PhaseStats,
    /// Calls and accumulated time of solve steps.
    pub phase_solve_step: PhaseStats,
}

/// Empty `MusStats` (nothing recorded), usable in `const`/`static` initialisers.
const MUS_STATS_ZERO: MusStats = MusStats {
    per_func: [MUS_FUNC_ZERO; NUM_MUS_FUNCTIONS],
    phase_batch_mus: PhaseStats {
        calls: 0,
        total_time: Duration::ZERO,
    },
    phase_solve_step: PhaseStats {
        calls: 0,
        total_time: Duration::ZERO,
    },
};

impl Default for MusStats {
    /// Returns empty statistics.
    fn default() -> Self {
        // Each use of a `const` is a fresh value, so no `.clone()` is needed.
        MUS_STATS_ZERO
    }
}
impl MusStats {
    /// Adds one search result to the counters of the function that produced it.
    fn record_search(&mut self, duration: Duration, outcome: MusOutcome, function: MusFunction) {
        let slot = &mut self.per_func[function.index()];
        slot.record(duration, &outcome);
    }

    /// Total number of searches recorded across all functions.
    fn total_searches(&self) -> usize {
        self.per_func.iter().fold(0, |acc, func| acc + func.total)
    }

    /// Writes the MUS report to stderr.
    ///
    /// The header with phase totals is always printed. It is followed either by
    /// a "no searches" note or by one section per function in declaration order.
    pub fn print_summary(&self) {
        let searches = self.total_searches();
        eprintln!("=== MUS Statistics ===");
        eprintln!(
            "Total searches: {} Batch phase: {} calls {:.3}s Solve steps: {} {:.3}s",
            searches,
            self.phase_batch_mus.calls,
            self.phase_batch_mus.total_time.as_secs_f64(),
            self.phase_solve_step.calls,
            self.phase_solve_step.total_time.as_secs_f64(),
        );
        if searches == 0 {
            eprintln!("(no MUS searches recorded)");
        } else {
            let report_order = [
                MusFunction::Size1,
                MusFunction::Slice,
                MusFunction::Cake,
                MusFunction::Quick,
            ];
            for func in report_order {
                self.per_func[func.index()].print_summary(func.label());
            }
            eprintln!();
        }
        eprintln!("======================");
    }
}
/// Process-wide MUS statistics shared by the helpers below.
static MUS_STATS: Mutex<MusStats> = Mutex::new(MUS_STATS_ZERO);

/// Runs `f` with exclusive access to the global MUS statistics.
///
/// If some thread panicked while holding the lock, the mutex is poisoned.
/// Rather than treating that as "statistics disabled", which would silently
/// drop every later measurement, the guard is recovered and collection
/// continues. The data is plain counters; at worst, the update that was
/// interrupted by the panic is partially applied.
fn with_mus_stats<R>(f: impl FnOnce(&mut MusStats) -> R) -> R {
    let mut guard = MUS_STATS
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    f(&mut guard)
}

/// Records one MUS search performed by `function` in the global statistics.
pub fn record_mus_search(duration: Duration, outcome: MusOutcome, function: MusFunction) {
    with_mus_stats(|stats| stats.record_search(duration, outcome, function));
}

/// Adds one batch-MUS phase run of length `duration` to the global statistics.
pub fn record_batch_mus_phase(duration: Duration) {
    with_mus_stats(|stats| stats.phase_batch_mus.record(duration));
}

/// Adds one solve step of length `duration` to the global statistics.
pub fn record_solve_step(duration: Duration) {
    with_mus_stats(|stats| stats.phase_solve_step.record(duration));
}

/// Clears the global MUS statistics back to the empty state.
pub fn reset_mus_stats() {
    // A `const` yields a fresh value on each use; no clone needed.
    with_mus_stats(|stats| *stats = MUS_STATS_ZERO);
}

/// Prints the global MUS statistics report to stderr.
pub fn print_mus_stats() {
    with_mus_stats(|stats| stats.print_summary());
}
/// Drop guard that times one phase and records the elapsed time when dropped.
///
/// Bind it to a named variable for the whole phase, e.g.
/// `let _timer = PhaseTimer::solve_step();`. Writing `let _ = ...` or
/// discarding the value drops it immediately and records a near-zero duration.
#[must_use = "the phase is recorded when the timer is dropped; bind it to a named variable"]
#[derive(Debug)]
pub struct PhaseTimer {
    /// When timing started.
    start: Instant,
    /// Which phase counter receives the elapsed time.
    kind: PhaseKind,
}

/// Selects which phase counter a `PhaseTimer` reports to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PhaseKind {
    /// Reported through `record_batch_mus_phase`.
    BatchMus,
    /// Reported through `record_solve_step`.
    SolveStep,
}

impl PhaseTimer {
    /// Starts timing a batch-MUS phase.
    pub fn batch_mus() -> Self {
        Self::begin(PhaseKind::BatchMus)
    }

    /// Starts timing a solve step.
    pub fn solve_step() -> Self {
        Self::begin(PhaseKind::SolveStep)
    }

    /// Starts timing a phase of the given kind.
    fn begin(kind: PhaseKind) -> Self {
        Self {
            start: Instant::now(),
            kind,
        }
    }
}

impl Drop for PhaseTimer {
    /// Records the time elapsed since construction under this timer's phase.
    fn drop(&mut self) {
        let d = self.start.elapsed();
        match self.kind {
            PhaseKind::BatchMus => record_batch_mus_phase(d),
            PhaseKind::SolveStep => record_solve_step(d),
        }
    }
}