use core::fmt;
/// Per-column and per-row status codes describing a basis.
///
/// The codes are stored verbatim as `i32`; their meaning is not defined in
/// this file. NOTE(review): presumably solver-native basis status values —
/// confirm against the code that fills them in.
#[derive(Debug, Clone)]
pub struct Basis {
    /// One status code per column; [`Basis::new`] sizes this to `num_cols`.
    pub col_status: Vec<i32>,
    /// One status code per row; [`Basis::new`] sizes this to `num_rows`.
    pub row_status: Vec<i32>,
}

impl Basis {
    /// Creates a basis holding `num_cols` column codes and `num_rows` row
    /// codes, every one of them initialised to `0`.
    #[must_use]
    pub fn new(num_cols: usize, num_rows: usize) -> Self {
        let col_status = vec![0_i32; num_cols];
        let row_status = vec![0_i32; num_rows];
        Self {
            col_status,
            row_status,
        }
    }
}
/// Owned result of a solve: scalar summary values plus three value vectors.
///
/// NOTE(review): vector lengths are not constrained here; presumably `primal`
/// and `reduced_costs` are per column and `dual` is per row — confirm.
#[derive(Debug, Clone)]
pub struct LpSolution {
    /// Objective value.
    pub objective: f64,
    /// Primal values.
    pub primal: Vec<f64>,
    /// Dual values.
    pub dual: Vec<f64>,
    /// Reduced costs.
    pub reduced_costs: Vec<f64>,
    /// Iteration count reported for the solve.
    pub iterations: u64,
    /// Solve time, in seconds.
    pub solve_time_seconds: f64,
}

/// Borrowed counterpart of [`LpSolution`]: the same fields, with slices in
/// place of vectors. The view is `Copy`, so it can be passed around freely.
#[derive(Debug, Clone, Copy)]
pub struct SolutionView<'a> {
    /// Objective value.
    pub objective: f64,
    /// Primal values (borrowed).
    pub primal: &'a [f64],
    /// Dual values (borrowed).
    pub dual: &'a [f64],
    /// Reduced costs (borrowed).
    pub reduced_costs: &'a [f64],
    /// Iteration count reported for the solve.
    pub iterations: u64,
    /// Solve time, in seconds.
    pub solve_time_seconds: f64,
}

impl SolutionView<'_> {
    /// Builds an owned [`LpSolution`] by copying each borrowed slice into a
    /// freshly allocated vector; scalar fields are copied as-is.
    #[must_use]
    pub fn to_owned(&self) -> LpSolution {
        // The view is `Copy`, so it can be taken apart by value.
        let Self {
            objective,
            primal,
            dual,
            reduced_costs,
            iterations,
            solve_time_seconds,
        } = *self;

        LpSolution {
            objective,
            primal: primal.to_vec(),
            dual: dual.to_vec(),
            reduced_costs: reduced_costs.to_vec(),
            iterations,
            solve_time_seconds,
        }
    }
}
/// Accumulated solver counters and timings.
///
/// `Default` yields all-zero counters and an empty histogram. Nothing in this
/// file updates the fields; the per-field descriptions below restate the
/// field names. NOTE(review): confirm exact semantics against the code that
/// maintains them.
#[derive(Debug, Clone, Default)]
pub struct SolverStatistics {
    /// Number of solves.
    pub solve_count: u64,
    /// Number of successful solves.
    pub success_count: u64,
    /// Number of failed solves.
    pub failure_count: u64,
    /// Iterations summed over solves.
    pub total_iterations: u64,
    /// Number of retries.
    pub retry_count: u64,
    /// Accumulated solve time, in seconds.
    pub total_solve_time_seconds: f64,
    /// Number of basis consistency failures.
    pub basis_consistency_failures: u64,
    /// Number of solves that succeeded on the first try.
    pub first_try_successes: u64,
    /// Number of times a basis was offered.
    pub basis_offered: u64,
    /// Number of model loads.
    pub load_model_count: u64,
    /// Accumulated model-load time, in seconds.
    pub total_load_model_time_seconds: f64,
    /// Accumulated bound-setting time, in seconds.
    pub total_set_bounds_time_seconds: f64,
    /// Accumulated basis-setting time, in seconds.
    pub total_basis_set_time_seconds: f64,
    /// Number of basis reconstructions.
    pub basis_reconstructions: u64,
    /// Histogram indexed by retry level. NOTE(review): bucket meaning is not
    /// visible in this file — confirm.
    pub retry_level_histogram: Vec<u64>,
}

impl SolverStatistics {
    /// Overwrites every field of `self` with the corresponding value in `src`.
    ///
    /// The end state equals `*self = src.clone()`, but the histogram is copied
    /// into the vector `self` already owns, so no allocation takes place when
    /// that vector's capacity already covers `src.retry_level_histogram.len()`.
    pub fn copy_from(&mut self, src: &SolverStatistics) {
        // Exhaustive destructuring (deliberately no `..`): adding a field to
        // the struct turns into a compile error here until it is copied too,
        // so this method cannot silently fall out of sync with the struct.
        let SolverStatistics {
            solve_count,
            success_count,
            failure_count,
            total_iterations,
            retry_count,
            total_solve_time_seconds,
            basis_consistency_failures,
            first_try_successes,
            basis_offered,
            load_model_count,
            total_load_model_time_seconds,
            total_set_bounds_time_seconds,
            total_basis_set_time_seconds,
            basis_reconstructions,
            retry_level_histogram,
        } = src;

        self.solve_count = *solve_count;
        self.success_count = *success_count;
        self.failure_count = *failure_count;
        self.total_iterations = *total_iterations;
        self.retry_count = *retry_count;
        self.total_solve_time_seconds = *total_solve_time_seconds;
        self.basis_consistency_failures = *basis_consistency_failures;
        self.first_try_successes = *first_try_successes;
        self.basis_offered = *basis_offered;
        self.load_model_count = *load_model_count;
        self.total_load_model_time_seconds = *total_load_model_time_seconds;
        self.total_set_bounds_time_seconds = *total_set_bounds_time_seconds;
        self.total_basis_set_time_seconds = *total_basis_set_time_seconds;
        self.basis_reconstructions = *basis_reconstructions;

        // `Vec::clone_from` reuses the existing allocation and, unlike
        // `resize` followed by `copy_from_slice`, never zero-fills elements
        // that are about to be overwritten.
        self.retry_level_histogram.clone_from(retry_level_histogram);
    }
}
/// Template data for one stage: dimensions, a sparse matrix held in three
/// parallel arrays, column/row bounds, objective coefficients, a handful of
/// layout counters, and optional scaling vectors.
///
/// NOTE(review): `col_starts` / `row_indices` / `values` look like a
/// compressed-sparse-column layout (`col_starts` with `num_cols + 1` entries
/// ending at `num_nz`) — confirm against the code that builds the template.
#[derive(Debug, Clone)]
pub struct StageTemplate {
    /// Number of columns.
    pub num_cols: usize,
    /// Number of rows.
    pub num_rows: usize,
    /// Number of stored matrix non-zeros.
    pub num_nz: usize,
    /// Start offsets of each column inside `row_indices` / `values`.
    pub col_starts: Vec<i32>,
    /// Row index of each stored matrix entry.
    pub row_indices: Vec<i32>,
    /// Value of each stored matrix entry.
    pub values: Vec<f64>,
    /// Column lower bounds.
    pub col_lower: Vec<f64>,
    /// Column upper bounds.
    pub col_upper: Vec<f64>,
    /// Objective coefficients.
    pub objective: Vec<f64>,
    /// Row lower bounds.
    pub row_lower: Vec<f64>,
    /// Row upper bounds.
    pub row_upper: Vec<f64>,
    /// NOTE(review): meaning not visible in this file — presumably the number
    /// of state entries; confirm.
    pub n_state: usize,
    /// NOTE(review): meaning not visible in this file — confirm.
    pub n_transfer: usize,
    /// NOTE(review): meaning not visible in this file — confirm.
    pub n_dual_relevant: usize,
    /// NOTE(review): meaning not visible in this file — confirm.
    pub n_hydro: usize,
    /// NOTE(review): meaning not visible in this file — confirm.
    pub max_par_order: usize,
    /// Column scaling factors. NOTE(review): presumably empty means
    /// "unscaled" — confirm.
    pub col_scale: Vec<f64>,
    /// Row scaling factors. NOTE(review): presumably empty means
    /// "unscaled" — confirm.
    pub row_scale: Vec<f64>,
}

impl StageTemplate {
    /// Returns a template with every counter set to `0` and every vector
    /// empty (no heap allocation is performed).
    #[must_use]
    pub fn empty() -> Self {
        Self {
            // Dimensions.
            num_cols: 0,
            num_rows: 0,
            num_nz: 0,
            // Sparse matrix arrays.
            col_starts: vec![],
            row_indices: vec![],
            values: vec![],
            // Bounds and objective.
            col_lower: vec![],
            col_upper: vec![],
            objective: vec![],
            row_lower: vec![],
            row_upper: vec![],
            // Layout counters.
            n_state: 0,
            n_transfer: 0,
            n_dual_relevant: 0,
            n_hydro: 0,
            max_par_order: 0,
            // Scaling vectors.
            col_scale: vec![],
            row_scale: vec![],
        }
    }
}

/// A batch of rows: a sparse matrix stored row by row in three parallel
/// arrays, plus per-row lower and upper bounds.
///
/// NOTE(review): `row_starts` / `col_indices` / `values` look like a
/// compressed-sparse-row layout — confirm against the producer.
#[derive(Debug, Clone)]
pub struct RowBatch {
    /// Number of rows in the batch.
    pub num_rows: usize,
    /// Start offsets of each row inside `col_indices` / `values`.
    pub row_starts: Vec<i32>,
    /// Column index of each stored entry.
    pub col_indices: Vec<i32>,
    /// Value of each stored entry.
    pub values: Vec<f64>,
    /// Row lower bounds.
    pub row_lower: Vec<f64>,
    /// Row upper bounds.
    pub row_upper: Vec<f64>,
}

impl RowBatch {
    /// Empties the batch: sets `num_rows` to `0` and clears every vector.
    ///
    /// The vectors keep their allocated capacity, so the batch can be
    /// refilled without reallocating.
    pub fn clear(&mut self) {
        let Self {
            num_rows,
            row_starts,
            col_indices,
            values,
            row_lower,
            row_upper,
        } = self;

        *num_rows = 0;
        row_starts.clear();
        col_indices.clear();
        values.clear();
        row_lower.clear();
        row_upper.clear();
    }
}
/// Failure modes reported by the solver layer.
///
/// The `Display` implementation renders a one-line, human-readable message
/// for each variant.
#[derive(Debug)]
pub enum SolverError {
    /// The LP has no feasible point.
    Infeasible,
    /// The LP is unbounded.
    Unbounded,
    /// The solve ran into numerical trouble; `message` carries the details.
    NumericalDifficulty {
        /// Free-form description of the difficulty.
        message: String,
    },
    /// The time limit was exceeded.
    TimeLimitExceeded {
        /// Elapsed time in seconds (displayed with three decimals).
        elapsed_seconds: f64,
    },
    /// The iteration limit was reached.
    IterationLimit {
        /// Number of iterations performed.
        iterations: u64,
    },
    /// An internal solver error, optionally tagged with a numeric code.
    InternalError {
        /// Free-form description of the failure.
        message: String,
        /// Numeric error code, when one is available.
        error_code: Option<i32>,
    },
    /// The requested operation is not supported; the payload says which.
    Unsupported(&'static str),
    /// The basis counts do not add up.
    ///
    /// NOTE(review): presumably `total_basic = col_basic + row_basic` and a
    /// consistent basis has `total_basic == num_row` — confirm at the site
    /// that constructs this variant.
    BasisInconsistent {
        /// Number of rows.
        num_row: i64,
        /// Total number of basic entries.
        total_basic: i64,
        /// Number of basic columns.
        col_basic: i64,
        /// Number of basic rows.
        row_basic: i64,
    },
    /// The basis and the LP disagree on the number of rows.
    BasisRowCountMismatch {
        /// Row count of the LP.
        lp_rows: usize,
        /// Row count of the basis.
        basis_rows: usize,
    },
}

impl fmt::Display for SolverError {
    /// Writes the one-line message for this error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            // Fixed messages need no formatting machinery.
            Self::Infeasible => f.write_str("LP is infeasible"),
            Self::Unbounded => f.write_str("LP is unbounded"),
            Self::NumericalDifficulty { message } => {
                write!(f, "numerical difficulty: {message}")
            }
            Self::TimeLimitExceeded { elapsed_seconds } => {
                write!(f, "time limit exceeded after {elapsed_seconds:.3}s")
            }
            Self::IterationLimit { iterations } => {
                write!(f, "iteration limit reached after {iterations} iterations")
            }
            // The error code is only mentioned when there is one.
            Self::InternalError {
                message,
                error_code: Some(code),
            } => write!(f, "internal solver error (code {code}): {message}"),
            Self::InternalError {
                message,
                error_code: None,
            } => write!(f, "internal solver error: {message}"),
            Self::Unsupported(msg) => write!(f, "unsupported operation: {msg}"),
            Self::BasisInconsistent {
                num_row,
                total_basic,
                col_basic,
                row_basic,
            } => write!(
                f,
                "basis inconsistent: num_row={num_row}, total_basic={total_basic} (col_basic={col_basic}, row_basic={row_basic})"
            ),
            Self::BasisRowCountMismatch {
                lp_rows,
                basis_rows,
            } => write!(
                f,
                "basis row count mismatch: lp_rows={lp_rows}, basis_rows={basis_rows}"
            ),
        }
    }
}

/// `SolverError` wraps no underlying error, so the default `source()` applies.
impl std::error::Error for SolverError {}
#[cfg(test)]
mod tests {
    use super::{Basis, RowBatch, SolutionView, SolverError, SolverStatistics, StageTemplate};

    /// `Basis::new` sizes both vectors as requested and zero-fills them.
    #[test]
    fn test_basis_new_dimensions_and_zero_fill() {
        let rb = Basis::new(3, 2);
        assert_eq!(rb.col_status.len(), 3);
        assert_eq!(rb.row_status.len(), 2);
        assert!(rb.col_status.iter().all(|&v| v == 0_i32));
        assert!(rb.row_status.iter().all(|&v| v == 0_i32));
    }

    /// Zero dimensions produce empty vectors.
    #[test]
    fn test_basis_new_empty() {
        let rb = Basis::new(0, 0);
        assert!(rb.col_status.is_empty());
        assert!(rb.row_status.is_empty());
    }

    /// `Debug` produces output and `Clone` yields an independent deep copy.
    #[test]
    fn test_basis_debug_and_clone() {
        let rb = Basis::new(2, 1);
        assert!(!format!("{rb:?}").is_empty());
        let cloned = rb.clone();
        assert_eq!(cloned.col_status, rb.col_status);
        assert_eq!(cloned.row_status, rb.row_status);
        // Mutating a clone must not affect the original.
        let mut cloned2 = rb.clone();
        cloned2.col_status[0] = 1_i32;
        assert_eq!(rb.col_status[0], 0_i32);
    }

    #[test]
    fn test_solver_error_display_infeasible() {
        let msg = SolverError::Infeasible.to_string();
        assert!(msg.contains("infeasible"));
    }

    /// Every variant (including both `InternalError` shapes and
    /// `Unsupported`) renders to a distinct message.
    #[test]
    fn test_solver_error_display_all_variants() {
        let variants = [
            SolverError::Infeasible,
            SolverError::Unbounded,
            SolverError::NumericalDifficulty {
                message: "factorization failed".to_string(),
            },
            SolverError::TimeLimitExceeded {
                elapsed_seconds: 60.0,
            },
            SolverError::IterationLimit { iterations: 10_000 },
            SolverError::InternalError {
                message: "segfault in HiGHS".to_string(),
                error_code: Some(-1),
            },
            SolverError::InternalError {
                message: "segfault in HiGHS".to_string(),
                error_code: None,
            },
            SolverError::Unsupported("row deletion"),
            SolverError::BasisInconsistent {
                num_row: 2,
                total_basic: 5,
                col_basic: 3,
                row_basic: 2,
            },
            SolverError::BasisRowCountMismatch {
                lp_rows: 3,
                basis_rows: 2,
            },
        ];
        let messages: Vec<String> = variants.iter().map(ToString::to_string).collect();
        for (i, earlier) in messages.iter().enumerate() {
            assert!(!earlier.is_empty());
            for later in &messages[i + 1..] {
                assert_ne!(earlier, later, "two variants render identically");
            }
        }
    }

    #[test]
    fn test_solver_error_is_std_error() {
        let err = SolverError::InternalError {
            message: "test".to_string(),
            error_code: None,
        };
        let _: &dyn std::error::Error = &err;
    }

    /// `Default` zeroes every counter and timing and leaves the histogram empty.
    #[test]
    fn test_solver_statistics_default_all_zero() {
        let stats = SolverStatistics::default();
        assert_eq!(stats.solve_count, 0);
        assert_eq!(stats.success_count, 0);
        assert_eq!(stats.failure_count, 0);
        assert_eq!(stats.total_iterations, 0);
        assert_eq!(stats.retry_count, 0);
        assert_eq!(stats.total_solve_time_seconds, 0.0);
        assert_eq!(stats.basis_consistency_failures, 0);
        assert_eq!(stats.first_try_successes, 0);
        assert_eq!(stats.basis_offered, 0);
        assert_eq!(stats.load_model_count, 0);
        assert_eq!(stats.total_load_model_time_seconds, 0.0);
        assert_eq!(stats.total_set_bounds_time_seconds, 0.0);
        assert_eq!(stats.total_basis_set_time_seconds, 0.0);
        assert_eq!(stats.basis_reconstructions, 0);
        assert!(stats.retry_level_histogram.is_empty());
    }

    /// Statistics fixture in which every field holds a distinct non-default value.
    fn make_fixture_statistics() -> SolverStatistics {
        SolverStatistics {
            solve_count: 11,
            success_count: 9,
            failure_count: 2,
            total_iterations: 3456,
            retry_count: 5,
            total_solve_time_seconds: 1.25,
            basis_consistency_failures: 1,
            first_try_successes: 7,
            basis_offered: 8,
            load_model_count: 4,
            total_load_model_time_seconds: 0.5,
            total_set_bounds_time_seconds: 0.25,
            total_basis_set_time_seconds: 0.125,
            basis_reconstructions: 6,
            retry_level_histogram: vec![1, 0, 2, 0, 3, 0, 0, 0, 4, 0, 0, 5],
        }
    }

    /// Field-by-field equality (the struct does not implement `PartialEq`).
    fn assert_statistics_eq(a: &SolverStatistics, b: &SolverStatistics) {
        assert_eq!(a.solve_count, b.solve_count);
        assert_eq!(a.success_count, b.success_count);
        assert_eq!(a.failure_count, b.failure_count);
        assert_eq!(a.total_iterations, b.total_iterations);
        assert_eq!(a.retry_count, b.retry_count);
        assert_eq!(a.total_solve_time_seconds, b.total_solve_time_seconds);
        assert_eq!(a.basis_consistency_failures, b.basis_consistency_failures);
        assert_eq!(a.first_try_successes, b.first_try_successes);
        assert_eq!(a.basis_offered, b.basis_offered);
        assert_eq!(a.load_model_count, b.load_model_count);
        assert_eq!(
            a.total_load_model_time_seconds,
            b.total_load_model_time_seconds
        );
        assert_eq!(
            a.total_set_bounds_time_seconds,
            b.total_set_bounds_time_seconds
        );
        assert_eq!(
            a.total_basis_set_time_seconds,
            b.total_basis_set_time_seconds
        );
        assert_eq!(a.basis_reconstructions, b.basis_reconstructions);
        assert_eq!(a.retry_level_histogram, b.retry_level_histogram);
    }

    /// `copy_from` reproduces the source regardless of the destination's
    /// previous histogram contents or length.
    #[test]
    fn test_solver_statistics_copy_from_equals_clone() {
        let src = make_fixture_statistics();
        let mut buf = SolverStatistics::default();
        buf.copy_from(&src);
        assert_statistics_eq(&buf, &src);
        assert_statistics_eq(&buf, &src.clone());

        // Destination starts with a shorter, non-zero histogram.
        let mut buf2 = SolverStatistics {
            retry_level_histogram: vec![99; 5],
            ..Default::default()
        };
        buf2.copy_from(&src);
        assert_statistics_eq(&buf2, &src);

        // Destination starts with a longer histogram and must shrink.
        let mut buf3 = SolverStatistics {
            retry_level_histogram: vec![99; 32],
            ..Default::default()
        };
        buf3.copy_from(&src);
        assert_statistics_eq(&buf3, &src);
    }

    #[test]
    fn test_solver_statistics_copy_from_no_realloc_second_call() {
        let src = make_fixture_statistics();
        let mut buf = SolverStatistics {
            retry_level_histogram: vec![0; src.retry_level_histogram.len()],
            ..Default::default()
        };
        buf.copy_from(&src);
        let ptr_before = buf.retry_level_histogram.as_ptr();
        buf.copy_from(&src);
        let ptr_after = buf.retry_level_histogram.as_ptr();
        assert_eq!(
            ptr_before, ptr_after,
            "second copy_from must not reallocate the histogram"
        );
        assert_statistics_eq(&buf, &src);
    }

    /// Small 3-column, 2-row template with three stored non-zeros.
    fn make_fixture_stage_template() -> StageTemplate {
        StageTemplate {
            num_cols: 3,
            num_rows: 2,
            num_nz: 3,
            col_starts: vec![0_i32, 2, 2, 3],
            row_indices: vec![0_i32, 1, 1],
            values: vec![1.0, 2.0, 1.0],
            col_lower: vec![0.0, 0.0, 0.0],
            col_upper: vec![10.0, f64::INFINITY, 8.0],
            objective: vec![0.0, 1.0, 50.0],
            row_lower: vec![6.0, 14.0],
            row_upper: vec![6.0, 14.0],
            n_state: 1,
            n_transfer: 0,
            n_dual_relevant: 1,
            n_hydro: 1,
            max_par_order: 0,
            col_scale: Vec::new(),
            row_scale: Vec::new(),
        }
    }

    #[test]
    fn test_stage_template_construction() {
        let tmpl = make_fixture_stage_template();
        assert_eq!(tmpl.num_cols, 3);
        assert_eq!(tmpl.num_rows, 2);
        assert_eq!(tmpl.num_nz, 3);
        assert_eq!(tmpl.col_starts, vec![0_i32, 2, 2, 3]);
        assert_eq!(tmpl.row_indices, vec![0_i32, 1, 1]);
        assert_eq!(tmpl.values, vec![1.0, 2.0, 1.0]);
        assert_eq!(tmpl.col_lower, vec![0.0, 0.0, 0.0]);
        assert_eq!(tmpl.col_upper[0], 10.0);
        assert!(tmpl.col_upper[1].is_infinite() && tmpl.col_upper[1] > 0.0);
        assert_eq!(tmpl.col_upper[2], 8.0);
        assert_eq!(tmpl.objective, vec![0.0, 1.0, 50.0]);
        assert_eq!(tmpl.row_lower, vec![6.0, 14.0]);
        assert_eq!(tmpl.row_upper, vec![6.0, 14.0]);
        assert_eq!(tmpl.n_state, 1);
        assert_eq!(tmpl.n_transfer, 0);
        assert_eq!(tmpl.n_dual_relevant, 1);
        assert_eq!(tmpl.n_hydro, 1);
        assert_eq!(tmpl.max_par_order, 0);
        assert!(tmpl.col_scale.is_empty());
        assert!(tmpl.row_scale.is_empty());
    }

    /// `StageTemplate::empty` zeroes every counter and leaves every vector empty.
    #[test]
    fn test_stage_template_empty() {
        let tmpl = StageTemplate::empty();
        assert_eq!(tmpl.num_cols, 0);
        assert_eq!(tmpl.num_rows, 0);
        assert_eq!(tmpl.num_nz, 0);
        assert!(tmpl.col_starts.is_empty());
        assert!(tmpl.row_indices.is_empty());
        assert!(tmpl.values.is_empty());
        assert!(tmpl.col_lower.is_empty());
        assert!(tmpl.col_upper.is_empty());
        assert!(tmpl.objective.is_empty());
        assert!(tmpl.row_lower.is_empty());
        assert!(tmpl.row_upper.is_empty());
        assert_eq!(tmpl.n_state, 0);
        assert_eq!(tmpl.n_transfer, 0);
        assert_eq!(tmpl.n_dual_relevant, 0);
        assert_eq!(tmpl.n_hydro, 0);
        assert_eq!(tmpl.max_par_order, 0);
        assert!(tmpl.col_scale.is_empty());
        assert!(tmpl.row_scale.is_empty());
    }

    /// Each `Display` branch emits its distinguishing fragment.
    #[test]
    fn test_solver_error_display_all_branches() {
        let cases = vec![
            ("Infeasible", SolverError::Infeasible, "infeasible"),
            ("Unbounded", SolverError::Unbounded, "unbounded"),
            (
                "NumericalDifficulty",
                SolverError::NumericalDifficulty {
                    message: "singular matrix".to_string(),
                },
                "singular matrix",
            ),
            (
                "TimeLimitExceeded",
                SolverError::TimeLimitExceeded {
                    elapsed_seconds: 60.0,
                },
                "60.000s",
            ),
            (
                "IterationLimit",
                SolverError::IterationLimit { iterations: 10_000 },
                "10000 iterations",
            ),
            (
                "InternalError/None",
                SolverError::InternalError {
                    message: "unknown failure".to_string(),
                    error_code: None,
                },
                "unknown failure",
            ),
            (
                "InternalError/Some",
                SolverError::InternalError {
                    message: "segfault in HiGHS".to_string(),
                    error_code: Some(-1),
                },
                "code -1",
            ),
            (
                "Unsupported",
                SolverError::Unsupported("row deletion"),
                "unsupported operation: row deletion",
            ),
            (
                "BasisInconsistent",
                SolverError::BasisInconsistent {
                    num_row: 2,
                    total_basic: 5,
                    col_basic: 3,
                    row_basic: 2,
                },
                "num_row=2",
            ),
            (
                "BasisRowCountMismatch",
                SolverError::BasisRowCountMismatch {
                    lp_rows: 3,
                    basis_rows: 2,
                },
                "lp_rows=3",
            ),
        ];
        for (name, err, expected_text) in cases {
            let msg = err.to_string();
            assert!(!msg.is_empty());
            assert!(
                msg.contains(expected_text),
                "{name} missing '{expected_text}' in '{msg}'"
            );
        }
    }

    /// Every variant coerces to `&dyn std::error::Error`.
    #[test]
    fn test_solver_error_is_std_error_all_variants() {
        let errors: Vec<SolverError> = vec![
            SolverError::Infeasible,
            SolverError::Unbounded,
            SolverError::NumericalDifficulty {
                message: "test".to_string(),
            },
            SolverError::TimeLimitExceeded {
                elapsed_seconds: 1.0,
            },
            SolverError::IterationLimit { iterations: 1 },
            SolverError::InternalError {
                message: "test".to_string(),
                error_code: None,
            },
            SolverError::InternalError {
                message: "test".to_string(),
                error_code: Some(-1),
            },
            SolverError::Unsupported("test"),
            SolverError::BasisInconsistent {
                num_row: 2,
                total_basic: 5,
                col_basic: 3,
                row_basic: 2,
            },
            SolverError::BasisRowCountMismatch {
                lp_rows: 3,
                basis_rows: 2,
            },
        ];
        for err in &errors {
            let _: &dyn std::error::Error = err;
        }
    }

    #[test]
    fn test_solution_view_to_owned() {
        let primal = [1.0, 2.0];
        let dual = [3.0];
        let rc = [4.0, 5.0];
        let view = SolutionView {
            objective: 42.0,
            primal: &primal,
            dual: &dual,
            reduced_costs: &rc,
            iterations: 7,
            solve_time_seconds: 0.5,
        };
        let owned = view.to_owned();
        assert_eq!(owned.objective, 42.0);
        assert_eq!(owned.primal, vec![1.0, 2.0]);
        assert_eq!(owned.dual, vec![3.0]);
        assert_eq!(owned.reduced_costs, vec![4.0, 5.0]);
        assert_eq!(owned.iterations, 7);
        assert_eq!(owned.solve_time_seconds, 0.5);
    }

    /// Using `view` after assigning it to `copy` only compiles if the type is `Copy`.
    #[test]
    fn test_solution_view_is_copy() {
        let primal = [1.0];
        let dual = [2.0];
        let rc = [3.0];
        let view = SolutionView {
            objective: 0.0,
            primal: &primal,
            dual: &dual,
            reduced_costs: &rc,
            iterations: 0,
            solve_time_seconds: 0.0,
        };
        let copy = view;
        assert_eq!(view.objective, copy.objective);
    }

    /// Two-row batch fixture shared by the `RowBatch` tests.
    fn make_fixture_row_batch() -> RowBatch {
        RowBatch {
            num_rows: 2,
            row_starts: vec![0_i32, 2, 4],
            col_indices: vec![0_i32, 1, 0, 1],
            values: vec![-5.0, 1.0, 3.0, 1.0],
            row_lower: vec![20.0, 80.0],
            row_upper: vec![f64::INFINITY, f64::INFINITY],
        }
    }

    #[test]
    fn test_row_batch_construction() {
        let batch = make_fixture_row_batch();
        assert_eq!(batch.num_rows, 2);
        assert_eq!(batch.row_starts.len(), 3);
        assert_eq!(batch.row_starts, vec![0_i32, 2, 4]);
        assert_eq!(batch.col_indices, vec![0_i32, 1, 0, 1]);
        assert_eq!(batch.values, vec![-5.0, 1.0, 3.0, 1.0]);
        assert_eq!(batch.row_lower, vec![20.0, 80.0]);
        assert!(batch.row_upper[0].is_infinite() && batch.row_upper[0] > 0.0);
        assert!(batch.row_upper[1].is_infinite() && batch.row_upper[1] > 0.0);
    }

    /// `clear` resets the row count and empties every vector.
    #[test]
    fn test_row_batch_clear() {
        let mut batch = make_fixture_row_batch();
        batch.clear();
        assert_eq!(batch.num_rows, 0);
        assert!(batch.row_starts.is_empty());
        assert!(batch.col_indices.is_empty());
        assert!(batch.values.is_empty());
        assert!(batch.row_lower.is_empty());
        assert!(batch.row_upper.is_empty());

        // Clearing an already-empty batch is a no-op.
        batch.clear();
        assert_eq!(batch.num_rows, 0);
        assert!(batch.values.is_empty());
    }
}