use crate::{
SolverInterface,
types::{Basis, RowBatch, SolutionView, SolverError, SolverStatistics, StageTemplate},
};
/// Wrapper around a [`SolverInterface`] implementation that remembers the
/// last profile applied to it, so that re-selecting the profile that is
/// already active issues no call to the inner solver.
///
/// Every other [`SolverInterface`] method is forwarded unchanged to the
/// wrapped solver.
pub struct ProfiledSolver<S: SolverInterface> {
    /// The wrapped solver.
    inner: S,
    /// The profile most recently applied to `inner` through this wrapper, or
    /// `S::Profile::default()` if none has been applied yet.
    current_profile: S::Profile,
    /// `true` when an individual option setter has been forwarded to `inner`
    /// since the last profile application. `inner` may then no longer match
    /// `current_profile`, so the next [`ProfiledSolver::set_profile`] must
    /// re-apply even if the requested profile equals the cached one.
    profile_dirty: bool,
}

impl<S: SolverInterface> ProfiledSolver<S> {
    /// Wraps `inner` without issuing any call to it.
    ///
    /// The cache starts at `S::Profile::default()`, i.e. the wrapper assumes
    /// `inner` is currently configured with the default profile.
    /// NOTE(review): confirm every inner solver is constructed that way;
    /// otherwise the first `set_profile(&S::Profile::default())` is skipped
    /// even though the inner settings differ.
    pub fn new(inner: S) -> Self {
        Self {
            inner,
            current_profile: S::Profile::default(),
            profile_dirty: false,
        }
    }

    /// Makes `profile` the active profile of the inner solver.
    ///
    /// The call into the inner solver is skipped only when `profile` equals
    /// the cached profile AND no individual option setter has been forwarded
    /// through this wrapper since that profile was applied.
    pub fn set_profile(&mut self, profile: &S::Profile) {
        if !self.profile_dirty && *profile == self.current_profile {
            return;
        }
        self.apply_profile(profile);
    }

    /// Returns the profile most recently applied through this wrapper.
    ///
    /// If an individual option setter has been forwarded since then, the
    /// inner solver's effective settings may differ from this value.
    pub fn current_profile(&self) -> &S::Profile {
        &self.current_profile
    }

    /// Returns a shared reference to the wrapped solver.
    pub fn inner(&self) -> &S {
        &self.inner
    }

    /// Returns a mutable reference to the wrapped solver.
    ///
    /// Options changed directly through this reference are NOT tracked by the
    /// profile cache; call [`SolverInterface::apply_profile`] on the wrapper
    /// afterwards to bring the cache and the inner solver back in sync.
    pub fn inner_mut(&mut self) -> &mut S {
        &mut self.inner
    }
}

impl<S: SolverInterface> SolverInterface for ProfiledSolver<S> {
    type Profile = S::Profile;

    /// Unconditionally applies `profile` to the inner solver, caches it as the
    /// active profile, and marks the cache as in sync with the inner solver.
    fn apply_profile(&mut self, profile: &S::Profile) {
        self.inner.apply_profile(profile);
        self.current_profile = *profile;
        self.profile_dirty = false;
    }

    fn load_model(&mut self, template: &StageTemplate) {
        self.inner.load_model(template);
    }

    fn add_rows(&mut self, rows: &RowBatch) {
        self.inner.add_rows(rows);
    }

    fn set_row_bounds(&mut self, indices: &[usize], lower: &[f64], upper: &[f64]) {
        self.inner.set_row_bounds(indices, lower, upper);
    }

    fn set_col_bounds(&mut self, indices: &[usize], lower: &[f64], upper: &[f64]) {
        self.inner.set_col_bounds(indices, lower, upper);
    }

    fn solve(&mut self, basis: Option<&Basis>) -> Result<SolutionView<'_>, SolverError> {
        self.inner.solve(basis)
    }

    fn get_basis(&mut self, out: &mut Basis) {
        self.inner.get_basis(out);
    }

    fn statistics(&self) -> SolverStatistics {
        self.inner.statistics()
    }

    fn name(&self) -> &'static str {
        self.inner.name()
    }

    fn solver_name_version(&self) -> String {
        self.inner.solver_name_version()
    }

    fn record_reconstruction_stats(&mut self) {
        self.inner.record_reconstruction_stats();
    }

    // The four setters below change inner options that a profile may also
    // control (HighsProfile, for instance, has a field for each of them), so
    // after forwarding them the cached profile can no longer be trusted.

    fn set_primal_feasibility_tolerance(&mut self, value: f64) {
        self.inner.set_primal_feasibility_tolerance(value);
        self.profile_dirty = true;
    }

    fn set_dual_feasibility_tolerance(&mut self, value: f64) {
        self.inner.set_dual_feasibility_tolerance(value);
        self.profile_dirty = true;
    }

    fn set_simplex_iteration_limit_profile(&mut self, value: u32) {
        self.inner.set_simplex_iteration_limit_profile(value);
        self.profile_dirty = true;
    }

    fn set_ipm_iteration_limit_profile(&mut self, value: u32) {
        self.inner.set_ipm_iteration_limit_profile(value);
        self.profile_dirty = true;
    }
}
#[cfg(test)]
mod tests {
    use super::ProfiledSolver;
    use crate::{
        HighsProfile, SolverInterface,
        types::{Basis, RowBatch, SolutionView, SolverError, SolverStatistics, StageTemplate},
    };

    /// One entry per call the mock solver received.
    #[derive(Debug, Clone, PartialEq)]
    enum RecordedCall {
        LoadModel,
        AddRows,
        SetRowBounds,
        SetColBounds,
        Solve,
        ApplyProfile(HighsProfile),
    }

    /// Mock solver that appends every recorded call to an in-memory log.
    ///
    /// All recording methods take `&mut self`, so a plain `Vec` is enough: no
    /// interior mutability is needed and the type is automatically `Send`
    /// (`HighsProfile` consists only of numeric fields), so no manual
    /// `unsafe impl Send` is required either.
    struct RecordingMockSolver {
        calls: Vec<RecordedCall>,
    }

    impl RecordingMockSolver {
        fn new() -> Self {
            Self { calls: Vec::new() }
        }

        /// All calls recorded so far, oldest first.
        fn recorded_calls(&self) -> &[RecordedCall] {
            &self.calls
        }
    }

    impl SolverInterface for RecordingMockSolver {
        type Profile = HighsProfile;

        fn apply_profile(&mut self, profile: &HighsProfile) {
            self.calls.push(RecordedCall::ApplyProfile(*profile));
        }

        fn load_model(&mut self, _template: &StageTemplate) {
            self.calls.push(RecordedCall::LoadModel);
        }

        fn add_rows(&mut self, _rows: &RowBatch) {
            self.calls.push(RecordedCall::AddRows);
        }

        fn set_row_bounds(&mut self, _indices: &[usize], _lower: &[f64], _upper: &[f64]) {
            self.calls.push(RecordedCall::SetRowBounds);
        }

        fn set_col_bounds(&mut self, _indices: &[usize], _lower: &[f64], _upper: &[f64]) {
            self.calls.push(RecordedCall::SetColBounds);
        }

        fn solve(&mut self, _basis: Option<&Basis>) -> Result<SolutionView<'_>, SolverError> {
            self.calls.push(RecordedCall::Solve);
            Err(SolverError::InternalError {
                message: "mock".to_string(),
                error_code: None,
            })
        }

        fn get_basis(&mut self, _out: &mut Basis) {}

        fn statistics(&self) -> SolverStatistics {
            SolverStatistics::default()
        }

        fn name(&self) -> &'static str {
            "RecordingMock"
        }

        fn solver_name_version(&self) -> String {
            "RecordingMockSolver 0.0.0".to_string()
        }

        fn set_primal_feasibility_tolerance(&mut self, _value: f64) {}
        fn set_dual_feasibility_tolerance(&mut self, _value: f64) {}
        fn set_simplex_iteration_limit_profile(&mut self, _value: u32) {}
        fn set_ipm_iteration_limit_profile(&mut self, _value: u32) {}
    }

    /// Keeps only the `ApplyProfile` entries of a call log.
    fn filter_profile_calls(calls: &[RecordedCall]) -> Vec<&RecordedCall> {
        calls
            .iter()
            .filter(|c| matches!(c, RecordedCall::ApplyProfile(_)))
            .collect()
    }

    /// Minimal one-column, zero-row model.
    fn make_test_template() -> StageTemplate {
        StageTemplate {
            num_cols: 1,
            num_rows: 0,
            num_nz: 0,
            col_starts: vec![0_i32, 0],
            row_indices: vec![],
            values: vec![],
            col_lower: vec![0.0],
            col_upper: vec![1.0],
            objective: vec![0.0],
            row_lower: vec![],
            row_upper: vec![],
            n_state: 0,
            n_transfer: 0,
            n_dual_relevant: 0,
            n_hydro: 0,
            max_par_order: 0,
            col_scale: vec![],
            row_scale: vec![],
        }
    }

    /// Empty row batch.
    fn make_test_row_batch() -> RowBatch {
        RowBatch {
            num_rows: 0,
            row_starts: vec![0_i32],
            col_indices: vec![],
            values: vec![],
            row_lower: vec![],
            row_upper: vec![],
        }
    }

    /// Asserts that `set_profile(&profile)` on a freshly wrapped mock
    /// dispatches exactly one `ApplyProfile(profile)`; `change` names the
    /// field(s) that differ from the default, for the failure message.
    fn assert_single_apply_profile(profile: HighsProfile, change: &str) {
        let mut solver = ProfiledSolver::new(RecordingMockSolver::new());
        solver.set_profile(&profile);
        let profile_calls = filter_profile_calls(solver.inner.recorded_calls());
        assert_eq!(
            profile_calls,
            vec![&RecordedCall::ApplyProfile(profile)],
            "expected one ApplyProfile(p) for {change} change"
        );
    }

    #[test]
    fn new_issues_no_ffi_calls() {
        let solver = ProfiledSolver::new(RecordingMockSolver::new());
        let calls = solver.inner.recorded_calls();
        assert!(
            calls.is_empty(),
            "expected zero calls after ProfiledSolver::new, got: {calls:?}"
        );
    }

    #[test]
    fn set_profile_noop_when_unchanged() {
        let mut solver = ProfiledSolver::new(RecordingMockSolver::new());
        solver.set_profile(&HighsProfile::default());
        let profile_calls = filter_profile_calls(solver.inner.recorded_calls());
        assert!(
            profile_calls.is_empty(),
            "expected zero apply_profile calls when profile unchanged, got: {profile_calls:?}"
        );
    }

    #[test]
    fn set_profile_dispatches_apply_profile_when_changed() {
        assert_single_apply_profile(
            HighsProfile {
                primal_feasibility_tolerance: 1e-7,
                ..HighsProfile::default()
            },
            "primal-only",
        );
        assert_single_apply_profile(
            HighsProfile {
                dual_feasibility_tolerance: 1e-7,
                ..HighsProfile::default()
            },
            "dual-only",
        );
        assert_single_apply_profile(
            HighsProfile {
                simplex_iteration_limit: 50_000,
                ..HighsProfile::default()
            },
            "simplex-cap-only",
        );
        assert_single_apply_profile(
            HighsProfile {
                ipm_iteration_limit: 5_000,
                ..HighsProfile::default()
            },
            "ipm-cap-only",
        );
        assert_single_apply_profile(
            HighsProfile {
                simplex_dual_edge_weight_strategy: 0,
                ..HighsProfile::default()
            },
            "dual-edge-weight-only",
        );
        assert_single_apply_profile(
            HighsProfile {
                simplex_price_strategy: 2,
                ..HighsProfile::default()
            },
            "price-strategy-only",
        );
    }

    #[test]
    fn set_profile_full_change_dispatches_single_apply_profile() {
        let mut solver = ProfiledSolver::new(RecordingMockSolver::new());
        let p = HighsProfile {
            primal_feasibility_tolerance: 1e-7,
            dual_feasibility_tolerance: 1e-7,
            simplex_iteration_limit: 50_000,
            ipm_iteration_limit: 5_000,
            simplex_dual_edge_weight_strategy: 0,
            simplex_scale_strategy: 2,
            simplex_price_strategy: 2,
        };
        solver.set_profile(&p);
        let profile_calls = filter_profile_calls(solver.inner.recorded_calls());
        assert_eq!(
            profile_calls,
            vec![&RecordedCall::ApplyProfile(p)],
            "expected exactly one ApplyProfile call with the complete profile"
        );
    }

    #[test]
    fn solver_interface_methods_forward_to_inner() {
        let mut solver = ProfiledSolver::new(RecordingMockSolver::new());
        let template = make_test_template();
        let rows = make_test_row_batch();
        solver.load_model(&template);
        solver.add_rows(&rows);
        solver.set_row_bounds(&[], &[], &[]);
        solver.set_col_bounds(&[], &[], &[]);
        let _ = solver.solve(None);
        let calls = solver.inner.recorded_calls();
        // Exact sequence: each call forwarded once, in order, and in
        // particular no ApplyProfile triggered by solve() or the others.
        let expected = [
            RecordedCall::LoadModel,
            RecordedCall::AddRows,
            RecordedCall::SetRowBounds,
            RecordedCall::SetColBounds,
            RecordedCall::Solve,
        ];
        assert_eq!(
            calls,
            expected.as_slice(),
            "expected each forwarded call exactly once and no ApplyProfile, got: {calls:?}"
        );
    }
}