1use crate::{
28 SolverInterface,
29 types::{Basis, RowBatch, SolutionView, SolverError, SolverStatistics, StageTemplate},
30};
31
/// Solver wrapper that remembers the most recently recorded profile.
///
/// The cached value lets [`ProfiledSolver::set_profile`] skip the call into
/// the wrapped solver when the requested profile equals the cached one.
pub struct ProfiledSolver<S: SolverInterface> {
    /// The wrapped solver; every solver operation is forwarded to it.
    inner: S,
    /// Profile last recorded by this wrapper; starts as `S::Profile::default()`.
    current_profile: S::Profile,
}
49
50impl<S: SolverInterface> ProfiledSolver<S> {
51 pub fn new(inner: S) -> Self {
58 Self {
59 inner,
60 current_profile: S::Profile::default(),
61 }
62 }
63
64 pub fn set_profile(&mut self, profile: &S::Profile) {
77 if *profile == self.current_profile {
78 return;
79 }
80 self.inner.apply_profile(profile);
81 self.current_profile = *profile;
82 }
83
84 pub fn current_profile(&self) -> &S::Profile {
90 &self.current_profile
91 }
92
93 pub fn inner(&self) -> &S {
98 &self.inner
99 }
100
101 pub fn inner_mut(&mut self) -> &mut S {
106 &mut self.inner
107 }
108}
109
110impl<S: SolverInterface> SolverInterface for ProfiledSolver<S> {
112 type Profile = S::Profile;
113
114 fn apply_profile(&mut self, profile: &S::Profile) {
115 self.inner.apply_profile(profile);
116 self.current_profile = *profile;
117 }
118
119 fn load_model(&mut self, template: &StageTemplate) {
120 self.inner.load_model(template);
121 }
122
123 fn add_rows(&mut self, rows: &RowBatch) {
124 self.inner.add_rows(rows);
125 }
126
127 fn set_row_bounds(&mut self, indices: &[usize], lower: &[f64], upper: &[f64]) {
128 self.inner.set_row_bounds(indices, lower, upper);
129 }
130
131 fn set_col_bounds(&mut self, indices: &[usize], lower: &[f64], upper: &[f64]) {
132 self.inner.set_col_bounds(indices, lower, upper);
133 }
134
135 fn solve(&mut self, basis: Option<&Basis>) -> Result<SolutionView<'_>, SolverError> {
136 self.inner.solve(basis)
143 }
144
145 fn get_basis(&mut self, out: &mut Basis) {
146 self.inner.get_basis(out);
147 }
148
149 fn statistics(&self) -> SolverStatistics {
150 self.inner.statistics()
151 }
152
153 fn statistics_into(&self, out: &mut SolverStatistics) {
154 self.inner.statistics_into(out);
155 }
156
157 fn name(&self) -> &'static str {
158 self.inner.name()
159 }
160
161 fn solver_name_version(&self) -> String {
162 self.inner.solver_name_version()
163 }
164
165 fn record_reconstruction_stats(&mut self) {
166 self.inner.record_reconstruction_stats();
167 }
168
169 fn reset_solver_state(&mut self) {
170 self.inner.reset_solver_state();
171 }
172}
173
174#[cfg(all(test, feature = "highs"))]
179mod tests {
180 use std::cell::RefCell;
181
182 use super::ProfiledSolver;
183 use crate::{
184 HighsProfile, SolverInterface,
185 types::{Basis, RowBatch, SolutionView, SolverError, SolverStatistics, StageTemplate},
186 };
187
    /// One entry in the call log kept by `RecordingMockSolver`.
    #[derive(Debug, Clone, PartialEq)]
    enum RecordedCall {
        /// Logged by the mock's `load_model`.
        LoadModel,
        /// Logged by the mock's `add_rows`.
        AddRows,
        /// Logged by the mock's `set_row_bounds`.
        SetRowBounds,
        /// Logged by the mock's `set_col_bounds`.
        SetColBounds,
        /// Logged by the mock's `solve`.
        Solve,
        /// Logged by the mock's `apply_profile`, carrying a copy of the profile.
        ApplyProfile(HighsProfile),
    }
200
    /// Test double that performs no solving and only logs which
    /// `SolverInterface` methods were invoked on it.
    struct RecordingMockSolver {
        /// Chronological log of recorded calls.
        calls: RefCell<Vec<RecordedCall>>,
    }
210
211 impl RecordingMockSolver {
212 fn new() -> Self {
213 Self {
214 calls: RefCell::new(Vec::new()),
215 }
216 }
217
218 pub(crate) fn recorded_calls(&self) -> Vec<RecordedCall> {
220 self.calls.borrow().clone()
221 }
222 }
223
    // SAFETY: the mock's only field is a `RefCell<Vec<RecordedCall>>`; moving
    // the mock to another thread moves that log with it, and the tests in this
    // module only ever use a mock from one thread at a time.
    // NOTE(review): `RefCell<T>` is already `Send` whenever `T: Send`, so this
    // impl looks redundant if `HighsProfile` is `Send` — confirm and consider
    // removing the `unsafe`.
    unsafe impl Send for RecordingMockSolver {}
232
233 impl SolverInterface for RecordingMockSolver {
234 type Profile = HighsProfile;
235
236 fn apply_profile(&mut self, profile: &HighsProfile) {
237 self.calls
238 .borrow_mut()
239 .push(RecordedCall::ApplyProfile(*profile));
240 }
241
242 fn load_model(&mut self, _template: &StageTemplate) {
243 self.calls.borrow_mut().push(RecordedCall::LoadModel);
244 }
245
246 fn add_rows(&mut self, _rows: &RowBatch) {
247 self.calls.borrow_mut().push(RecordedCall::AddRows);
248 }
249
250 fn set_row_bounds(&mut self, _indices: &[usize], _lower: &[f64], _upper: &[f64]) {
251 self.calls.borrow_mut().push(RecordedCall::SetRowBounds);
252 }
253
254 fn set_col_bounds(&mut self, _indices: &[usize], _lower: &[f64], _upper: &[f64]) {
255 self.calls.borrow_mut().push(RecordedCall::SetColBounds);
256 }
257
258 fn solve(&mut self, _basis: Option<&Basis>) -> Result<SolutionView<'_>, SolverError> {
259 self.calls.borrow_mut().push(RecordedCall::Solve);
260 Err(SolverError::InternalError {
261 message: "mock".to_string(),
262 error_code: None,
263 })
264 }
265
266 fn get_basis(&mut self, _out: &mut Basis) {}
267
268 fn statistics(&self) -> SolverStatistics {
269 SolverStatistics::default()
270 }
271
272 fn statistics_into(&self, out: &mut SolverStatistics) {
273 out.copy_from(&SolverStatistics::default());
274 }
275
276 fn name(&self) -> &'static str {
277 "RecordingMock"
278 }
279
280 fn solver_name_version(&self) -> String {
281 "RecordingMockSolver 0.0.0".to_string()
282 }
283 }
284
285 fn filter_profile_calls(calls: &[RecordedCall]) -> Vec<&RecordedCall> {
289 calls
290 .iter()
291 .filter(|c| matches!(c, RecordedCall::ApplyProfile(_)))
292 .collect()
293 }
294
295 fn make_test_template() -> StageTemplate {
296 StageTemplate {
297 num_cols: 1,
298 num_rows: 0,
299 num_nz: 0,
300 col_starts: vec![0_i32, 0],
301 row_indices: vec![],
302 values: vec![],
303 col_lower: vec![0.0],
304 col_upper: vec![1.0],
305 objective: vec![0.0],
306 row_lower: vec![],
307 row_upper: vec![],
308 n_state: 0,
309 n_transfer: 0,
310 n_dual_relevant: 0,
311 n_hydro: 0,
312 max_par_order: 0,
313 col_scale: vec![],
314 row_scale: vec![],
315 }
316 }
317
318 fn make_test_row_batch() -> RowBatch {
319 RowBatch {
320 num_rows: 0,
321 row_starts: vec![0_i32],
322 col_indices: vec![],
323 values: vec![],
324 row_lower: vec![],
325 row_upper: vec![],
326 }
327 }
328
329 #[test]
331 fn new_issues_no_ffi_calls() {
332 let mock = RecordingMockSolver::new();
333 let solver = ProfiledSolver::new(mock);
334 let calls = solver.inner.recorded_calls();
335 assert!(
336 calls.is_empty(),
337 "expected zero calls after ProfiledSolver::new, got: {calls:?}"
338 );
339 }
340
341 #[test]
344 fn set_profile_noop_when_unchanged() {
345 let mock = RecordingMockSolver::new();
346 let mut solver = ProfiledSolver::new(mock);
347
348 solver.set_profile(&HighsProfile::default());
350
351 let calls = solver.inner.recorded_calls();
352 let profile_calls = filter_profile_calls(&calls);
353 assert!(
354 profile_calls.is_empty(),
355 "expected zero apply_profile calls when profile unchanged, got: {profile_calls:?}"
356 );
357 }
358
359 #[test]
367 fn set_profile_dispatches_apply_profile_when_changed() {
368 {
370 let mock = RecordingMockSolver::new();
371 let mut solver = ProfiledSolver::new(mock);
372 let p = HighsProfile {
373 primal_feasibility_tolerance: 1e-7,
374 ..HighsProfile::default()
375 };
376 solver.set_profile(&p);
377 let calls = solver.inner.recorded_calls();
378 let profile_calls = filter_profile_calls(&calls);
379 assert_eq!(
380 profile_calls,
381 vec![&RecordedCall::ApplyProfile(p)],
382 "expected one ApplyProfile(p) for primal-only change"
383 );
384 }
385
386 {
388 let mock = RecordingMockSolver::new();
389 let mut solver = ProfiledSolver::new(mock);
390 let p = HighsProfile {
391 dual_feasibility_tolerance: 1e-7,
392 ..HighsProfile::default()
393 };
394 solver.set_profile(&p);
395 let calls = solver.inner.recorded_calls();
396 let profile_calls = filter_profile_calls(&calls);
397 assert_eq!(
398 profile_calls,
399 vec![&RecordedCall::ApplyProfile(p)],
400 "expected one ApplyProfile(p) for dual-only change"
401 );
402 }
403
404 {
406 let mock = RecordingMockSolver::new();
407 let mut solver = ProfiledSolver::new(mock);
408 let p = HighsProfile {
409 simplex_iteration_limit: 50_000,
410 ..HighsProfile::default()
411 };
412 solver.set_profile(&p);
413 let calls = solver.inner.recorded_calls();
414 let profile_calls = filter_profile_calls(&calls);
415 assert_eq!(
416 profile_calls,
417 vec![&RecordedCall::ApplyProfile(p)],
418 "expected one ApplyProfile(p) for simplex-cap-only change"
419 );
420 }
421
422 {
424 let mock = RecordingMockSolver::new();
425 let mut solver = ProfiledSolver::new(mock);
426 let p = HighsProfile {
427 ipm_iteration_limit: 5_000,
428 ..HighsProfile::default()
429 };
430 solver.set_profile(&p);
431 let calls = solver.inner.recorded_calls();
432 let profile_calls = filter_profile_calls(&calls);
433 assert_eq!(
434 profile_calls,
435 vec![&RecordedCall::ApplyProfile(p)],
436 "expected one ApplyProfile(p) for ipm-cap-only change"
437 );
438 }
439
440 {
442 let mock = RecordingMockSolver::new();
443 let mut solver = ProfiledSolver::new(mock);
444 let p = HighsProfile {
445 simplex_dual_edge_weight_strategy: 0, ..HighsProfile::default()
447 };
448 solver.set_profile(&p);
449 let calls = solver.inner.recorded_calls();
450 let profile_calls = filter_profile_calls(&calls);
451 assert_eq!(
452 profile_calls,
453 vec![&RecordedCall::ApplyProfile(p)],
454 "expected one ApplyProfile(p) for dual-edge-weight-only change"
455 );
456 }
457
458 {
460 let mock = RecordingMockSolver::new();
461 let mut solver = ProfiledSolver::new(mock);
462 let p = HighsProfile {
463 simplex_price_strategy: 2, ..HighsProfile::default()
465 };
466 solver.set_profile(&p);
467 let calls = solver.inner.recorded_calls();
468 let profile_calls = filter_profile_calls(&calls);
469 assert_eq!(
470 profile_calls,
471 vec![&RecordedCall::ApplyProfile(p)],
472 "expected one ApplyProfile(p) for price-strategy-only change"
473 );
474 }
475 }
476
477 #[test]
481 fn set_profile_full_change_dispatches_single_apply_profile() {
482 let mock = RecordingMockSolver::new();
483 let mut solver = ProfiledSolver::new(mock);
484
485 let p = HighsProfile {
486 primal_feasibility_tolerance: 1e-7,
487 dual_feasibility_tolerance: 1e-7,
488 simplex_iteration_limit: 50_000,
489 ipm_iteration_limit: 5_000,
490 simplex_dual_edge_weight_strategy: 0, simplex_scale_strategy: 2, simplex_price_strategy: 2, };
494 solver.set_profile(&p);
495
496 let calls = solver.inner.recorded_calls();
497 let profile_calls: Vec<_> = filter_profile_calls(&calls).into_iter().cloned().collect();
498
499 assert_eq!(
500 profile_calls,
501 vec![RecordedCall::ApplyProfile(p)],
502 "expected exactly one ApplyProfile call with the complete profile"
503 );
504 }
505
506 #[test]
512 fn solver_interface_methods_forward_to_inner() {
513 let mock = RecordingMockSolver::new();
514 let mut solver = ProfiledSolver::new(mock);
515
516 let template = make_test_template();
517 let rows = make_test_row_batch();
518
519 solver.load_model(&template);
520 solver.add_rows(&rows);
521 solver.set_row_bounds(&[], &[], &[]);
522 solver.set_col_bounds(&[], &[], &[]);
523 let _ = solver.solve(None);
524
525 let calls = solver.inner.recorded_calls();
526 assert!(
527 calls.contains(&RecordedCall::LoadModel),
528 "expected LoadModel in call log, got: {calls:?}"
529 );
530 assert!(
531 calls.contains(&RecordedCall::AddRows),
532 "expected AddRows in call log, got: {calls:?}"
533 );
534 assert!(
535 calls.contains(&RecordedCall::SetRowBounds),
536 "expected SetRowBounds in call log, got: {calls:?}"
537 );
538 assert!(
539 calls.contains(&RecordedCall::SetColBounds),
540 "expected SetColBounds in call log, got: {calls:?}"
541 );
542 assert!(
543 calls.contains(&RecordedCall::Solve),
544 "expected Solve in call log, got: {calls:?}"
545 );
546 let profile_calls = filter_profile_calls(&calls);
550 assert!(
551 profile_calls.is_empty(),
552 "solve() must not trigger an ApplyProfile call, got: {calls:?}"
553 );
554 }
555}