Skip to main content

cobre_solver/
profiled.rs

1//! Generic `ProfiledSolver<S>` wrapper with per-phase LP-solver configuration.
2//!
3//! [`ProfiledSolver`] wraps any [`SolverInterface`] implementor, tracks the
4//! currently-applied [`SolveProfile`], and skips FFI option-setter calls when
5//! the new profile matches the current one (delta-only dispatch). All other
6//! [`SolverInterface`] methods are transparently forwarded to the inner solver.
7//!
8//! # Construction
9//!
10//! `ProfiledSolver::new(inner)` assumes the inner solver is in a state
11//! consistent with `SolveProfile::default()` and issues no FFI calls.
12//!
13//! # Usage
14//!
15//! ```rust
16//! use cobre_solver::{ProfiledSolver, SolveProfile, SolverInterface, HighsSolver};
17//!
18//! let inner = HighsSolver::new().expect("HiGHS init");
19//! let mut solver = ProfiledSolver::new(inner);
20//! solver.set_profile(&SolveProfile::default());
21//! assert_eq!(solver.current_profile(), &SolveProfile::default());
22//! ```
23
24use crate::{
25    SolveProfile, SolverInterface,
26    types::{Basis, RowBatch, SolutionView, SolverError, SolverStatistics, StageTemplate},
27};
28
29/// Wraps any [`SolverInterface`] implementor with per-phase profile
30/// configuration.
31///
32/// Tracks the currently-applied profile and skips no-op FFI calls when the
33/// same profile is reapplied. The wrapper itself implements [`SolverInterface`]
34/// by transparently forwarding all method calls to the inner solver.
35/// [`ProfiledSolver::set_profile`] is the only non-trait-method addition.
36///
37/// # Generic parameter
38///
39/// `S` must implement [`SolverInterface`]. The wrapper is resolved at compile
40/// time (monomorphization) to preserve zero-cost forwarding on the hot path.
41pub struct ProfiledSolver<S: SolverInterface> {
42    inner: S,
43    current_profile: SolveProfile,
44}
45
46impl<S: SolverInterface> ProfiledSolver<S> {
47    /// Wrap an existing solver with the default profile.
48    ///
49    /// The wrapper does NOT issue any FFI calls on construction — the inner
50    /// solver is assumed to be in a state consistent with
51    /// `SolveProfile::default()`, which is exactly how it has been
52    /// constructed historically.
53    pub fn new(inner: S) -> Self {
54        Self {
55            inner,
56            current_profile: SolveProfile::default(),
57        }
58    }
59
60    /// Apply a new profile to the inner solver.
61    ///
62    /// Only fields that differ from `current_profile` trigger trait-method
63    /// calls on the inner solver. The dispatch order is deterministic:
64    /// primal feasibility → dual feasibility → simplex cap → IPM cap.
65    ///
66    /// If `profile == current_profile`, this method returns immediately with
67    /// zero inner method calls.
68    ///
69    /// After the call returns, `current_profile() == profile`.
70    ///
71    /// # Call-site contract
72    ///
73    /// Callers invoke this once per phase boundary. It is NOT intended to be
74    /// called inside the hot solve loop.
75    pub fn set_profile(&mut self, profile: &SolveProfile) {
76        if *profile == self.current_profile {
77            return;
78        }
79        // Exact bit-equality is intentional: `SolveProfile` is `Copy + PartialEq`
80        // precisely so that field-level delta tracking can compare stored vs. requested
81        // values without a margin-of-error. The goal is to avoid FFI calls only when
82        // the exact same bitpattern was previously applied — not to express a numerical
83        // tolerance concept. `float_cmp` is suppressed for this reason.
84        #[allow(clippy::float_cmp)]
85        if profile.primal_feasibility_tolerance != self.current_profile.primal_feasibility_tolerance
86        {
87            self.inner
88                .set_primal_feasibility_tolerance(profile.primal_feasibility_tolerance);
89        }
90        #[allow(clippy::float_cmp)]
91        if profile.dual_feasibility_tolerance != self.current_profile.dual_feasibility_tolerance {
92            self.inner
93                .set_dual_feasibility_tolerance(profile.dual_feasibility_tolerance);
94        }
95        if profile.simplex_iteration_limit != self.current_profile.simplex_iteration_limit {
96            self.inner
97                .set_simplex_iteration_limit_profile(profile.simplex_iteration_limit);
98        }
99        if profile.ipm_iteration_limit != self.current_profile.ipm_iteration_limit {
100            self.inner
101                .set_ipm_iteration_limit_profile(profile.ipm_iteration_limit);
102        }
103        self.current_profile = *profile;
104    }
105
106    /// Read-only access to the currently applied profile.
107    ///
108    /// Returns the profile that was last successfully applied via
109    /// [`ProfiledSolver::set_profile`], or `SolveProfile::default()` if no
110    /// profile has been applied yet.
111    pub fn current_profile(&self) -> &SolveProfile {
112        &self.current_profile
113    }
114
115    /// Shared reference to the wrapped inner solver.
116    ///
117    /// Intended for test code and rare adapter sites that need to inspect
118    /// mock-specific fields on the inner solver. Not used on the hot path.
119    pub fn inner(&self) -> &S {
120        &self.inner
121    }
122
123    /// Exclusive reference to the wrapped inner solver.
124    ///
125    /// Intended for test code and rare adapter sites that need to mutate
126    /// mock-specific state on the inner solver. Not used on the hot path.
127    pub fn inner_mut(&mut self) -> &mut S {
128        &mut self.inner
129    }
130}
131
132// Transparent `SolverInterface` forwarding.
133impl<S: SolverInterface> SolverInterface for ProfiledSolver<S> {
134    fn load_model(&mut self, template: &StageTemplate) {
135        self.inner.load_model(template);
136    }
137
138    fn add_rows(&mut self, rows: &RowBatch) {
139        self.inner.add_rows(rows);
140    }
141
142    fn set_row_bounds(&mut self, indices: &[usize], lower: &[f64], upper: &[f64]) {
143        self.inner.set_row_bounds(indices, lower, upper);
144    }
145
146    fn set_col_bounds(&mut self, indices: &[usize], lower: &[f64], upper: &[f64]) {
147        self.inner.set_col_bounds(indices, lower, upper);
148    }
149
150    fn solve(&mut self, basis: Option<&Basis>) -> Result<SolutionView<'_>, SolverError> {
151        self.inner.solve(basis)
152    }
153
154    fn get_basis(&mut self, out: &mut Basis) {
155        self.inner.get_basis(out);
156    }
157
158    fn statistics(&self) -> SolverStatistics {
159        self.inner.statistics()
160    }
161
162    fn name(&self) -> &'static str {
163        self.inner.name()
164    }
165
166    fn solver_name_version(&self) -> String {
167        self.inner.solver_name_version()
168    }
169
170    fn record_reconstruction_stats(&mut self) {
171        self.inner.record_reconstruction_stats();
172    }
173
174    fn set_primal_feasibility_tolerance(&mut self, value: f64) {
175        self.inner.set_primal_feasibility_tolerance(value);
176    }
177
178    fn set_dual_feasibility_tolerance(&mut self, value: f64) {
179        self.inner.set_dual_feasibility_tolerance(value);
180    }
181
182    fn set_simplex_iteration_limit_profile(&mut self, value: u32) {
183        self.inner.set_simplex_iteration_limit_profile(value);
184    }
185
186    fn set_ipm_iteration_limit_profile(&mut self, value: u32) {
187        self.inner.set_ipm_iteration_limit_profile(value);
188    }
189}
190
#[cfg(test)]
mod tests {
    use super::ProfiledSolver;
    use crate::{
        SolveProfile, SolverInterface,
        types::{Basis, RowBatch, SolutionView, SolverError, SolverStatistics, StageTemplate},
    };

    // ── RecordingMockSolver ───────────────────────────────────────────────────

    /// Recorded invocation of a [`SolverInterface`] method.
    #[derive(Debug, Clone, PartialEq)]
    enum RecordedCall {
        LoadModel,
        AddRows,
        SetRowBounds,
        SetColBounds,
        Solve,
        SetPrimalFeas(f64),
        SetDualFeas(f64),
        SetSimplexCap(u32),
        SetIpmCap(u32),
    }

    /// A minimal [`SolverInterface`] implementor that appends each recorded
    /// invocation to a call log.
    ///
    /// Every recording method takes `&mut self`, so a plain `Vec` suffices.
    /// No interior mutability is needed, and the mock is automatically `Send`
    /// (a `Vec` of `Send` items is `Send`) without any `unsafe impl`.
    ///
    /// `solve` always returns `Err(SolverError::InternalError)` because the
    /// mock does not represent a real LP solver; tests only inspect the log.
    struct RecordingMockSolver {
        calls: Vec<RecordedCall>,
    }

    impl RecordingMockSolver {
        fn new() -> Self {
            Self { calls: Vec::new() }
        }

        /// All calls recorded so far, in invocation order.
        fn recorded_calls(&self) -> &[RecordedCall] {
            &self.calls
        }
    }

    impl SolverInterface for RecordingMockSolver {
        fn load_model(&mut self, _template: &StageTemplate) {
            self.calls.push(RecordedCall::LoadModel);
        }

        fn add_rows(&mut self, _rows: &RowBatch) {
            self.calls.push(RecordedCall::AddRows);
        }

        fn set_row_bounds(&mut self, _indices: &[usize], _lower: &[f64], _upper: &[f64]) {
            self.calls.push(RecordedCall::SetRowBounds);
        }

        fn set_col_bounds(&mut self, _indices: &[usize], _lower: &[f64], _upper: &[f64]) {
            self.calls.push(RecordedCall::SetColBounds);
        }

        fn solve(&mut self, _basis: Option<&Basis>) -> Result<SolutionView<'_>, SolverError> {
            self.calls.push(RecordedCall::Solve);
            Err(SolverError::InternalError {
                message: "mock".to_string(),
                error_code: None,
            })
        }

        fn get_basis(&mut self, _out: &mut Basis) {}

        fn statistics(&self) -> SolverStatistics {
            SolverStatistics::default()
        }

        fn name(&self) -> &'static str {
            "RecordingMock"
        }

        fn solver_name_version(&self) -> String {
            "RecordingMockSolver 0.0.0".to_string()
        }

        fn set_primal_feasibility_tolerance(&mut self, value: f64) {
            self.calls.push(RecordedCall::SetPrimalFeas(value));
        }

        fn set_dual_feasibility_tolerance(&mut self, value: f64) {
            self.calls.push(RecordedCall::SetDualFeas(value));
        }

        fn set_simplex_iteration_limit_profile(&mut self, value: u32) {
            self.calls.push(RecordedCall::SetSimplexCap(value));
        }

        fn set_ipm_iteration_limit_profile(&mut self, value: u32) {
            self.calls.push(RecordedCall::SetIpmCap(value));
        }
    }

    // ── Helpers ───────────────────────────────────────────────────────────────

    /// Filter recorded calls to extract only profile setter calls.
    fn filter_profile_calls(calls: &[RecordedCall]) -> Vec<&RecordedCall> {
        calls
            .iter()
            .filter(|c| {
                matches!(
                    c,
                    RecordedCall::SetPrimalFeas(_)
                        | RecordedCall::SetDualFeas(_)
                        | RecordedCall::SetSimplexCap(_)
                        | RecordedCall::SetIpmCap(_)
                )
            })
            .collect()
    }

    fn make_test_template() -> StageTemplate {
        StageTemplate {
            num_cols: 1,
            num_rows: 0,
            num_nz: 0,
            col_starts: vec![0_i32, 0],
            row_indices: vec![],
            values: vec![],
            col_lower: vec![0.0],
            col_upper: vec![1.0],
            objective: vec![0.0],
            row_lower: vec![],
            row_upper: vec![],
            n_state: 0,
            n_transfer: 0,
            n_dual_relevant: 0,
            n_hydro: 0,
            max_par_order: 0,
            col_scale: vec![],
            row_scale: vec![],
        }
    }

    fn make_test_row_batch() -> RowBatch {
        RowBatch {
            num_rows: 0,
            row_starts: vec![0_i32],
            col_indices: vec![],
            values: vec![],
            row_lower: vec![],
            row_upper: vec![],
        }
    }

    // ── AC-3 ─────────────────────────────────────────────────────────────────

    /// AC-3: `ProfiledSolver::new` must not dispatch any FFI setter calls.
    #[test]
    fn new_issues_no_ffi_calls() {
        let solver = ProfiledSolver::new(RecordingMockSolver::new());

        let calls = solver.inner().recorded_calls();
        assert!(
            calls.is_empty(),
            "expected zero calls after ProfiledSolver::new, got: {calls:?}"
        );
        assert_eq!(solver.current_profile(), &SolveProfile::default());
    }

    // ── AC-4 ─────────────────────────────────────────────────────────────────

    /// AC-4: `set_profile` with a profile equal to `current_profile` issues
    /// zero FFI setter calls.
    #[test]
    fn set_profile_noop_when_unchanged() {
        let mut solver = ProfiledSolver::new(RecordingMockSolver::new());

        // Apply the default profile — same as the initial `current_profile`.
        solver.set_profile(&SolveProfile::default());

        let profile_calls = filter_profile_calls(solver.inner().recorded_calls());
        assert!(
            profile_calls.is_empty(),
            "expected zero profile setter calls when profile unchanged, got: {profile_calls:?}"
        );
    }

    /// Re-applying a non-default profile that is already applied issues no
    /// additional setter calls.
    #[test]
    fn set_profile_noop_when_reapplying_non_default() {
        let mut solver = ProfiledSolver::new(RecordingMockSolver::new());
        let p = SolveProfile {
            primal_feasibility_tolerance: 1e-7,
            simplex_iteration_limit: 50_000,
            ..SolveProfile::default()
        };

        solver.set_profile(&p);
        let after_first = solver.inner().recorded_calls().len();
        solver.set_profile(&p);

        assert_eq!(
            solver.inner().recorded_calls().len(),
            after_first,
            "re-applying the current profile must not dispatch any setter"
        );
        assert_eq!(solver.current_profile(), &p);
    }

    // ── AC-5 ─────────────────────────────────────────────────────────────────

    /// AC-5: `set_profile` with exactly one field changed dispatches exactly
    /// one setter call matching that field, and no other setter calls.
    #[test]
    fn set_profile_dispatches_only_changed_field() {
        let default = SolveProfile::default();

        // ── sub-test 1: only primal tolerance changed ──
        {
            let mut solver = ProfiledSolver::new(RecordingMockSolver::new());
            let p = SolveProfile {
                primal_feasibility_tolerance: 1e-7,
                ..default
            };
            solver.set_profile(&p);
            assert_eq!(
                filter_profile_calls(solver.inner().recorded_calls()),
                vec![&RecordedCall::SetPrimalFeas(1e-7)],
                "expected only SetPrimalFeas(1e-7) for primal-only change"
            );
        }

        // ── sub-test 2: only dual tolerance changed ──
        {
            let mut solver = ProfiledSolver::new(RecordingMockSolver::new());
            let p = SolveProfile {
                dual_feasibility_tolerance: 1e-7,
                ..default
            };
            solver.set_profile(&p);
            assert_eq!(
                filter_profile_calls(solver.inner().recorded_calls()),
                vec![&RecordedCall::SetDualFeas(1e-7)],
                "expected only SetDualFeas(1e-7) for dual-only change"
            );
        }

        // ── sub-test 3: only simplex cap changed ──
        {
            let mut solver = ProfiledSolver::new(RecordingMockSolver::new());
            let p = SolveProfile {
                simplex_iteration_limit: 50_000,
                ..default
            };
            solver.set_profile(&p);
            assert_eq!(
                filter_profile_calls(solver.inner().recorded_calls()),
                vec![&RecordedCall::SetSimplexCap(50_000)],
                "expected only SetSimplexCap(50_000) for simplex-only change"
            );
        }

        // ── sub-test 4: only IPM cap changed ──
        {
            let mut solver = ProfiledSolver::new(RecordingMockSolver::new());
            let p = SolveProfile {
                ipm_iteration_limit: 5_000,
                ..default
            };
            solver.set_profile(&p);
            assert_eq!(
                filter_profile_calls(solver.inner().recorded_calls()),
                vec![&RecordedCall::SetIpmCap(5_000)],
                "expected only SetIpmCap(5_000) for ipm-only change"
            );
        }
    }

    /// Reverting to the default profile re-dispatches exactly the field that
    /// had been changed, restoring its default value.
    #[test]
    fn set_profile_reverting_to_default_restores_changed_field() {
        let default = SolveProfile::default();
        let mut solver = ProfiledSolver::new(RecordingMockSolver::new());

        solver.set_profile(&SolveProfile {
            simplex_iteration_limit: 50_000,
            ..default
        });
        solver.set_profile(&default);

        assert_eq!(
            filter_profile_calls(solver.inner().recorded_calls()),
            vec![
                &RecordedCall::SetSimplexCap(50_000),
                &RecordedCall::SetSimplexCap(default.simplex_iteration_limit),
            ],
            "expected the simplex cap to be set and then restored to its default"
        );
        assert_eq!(solver.current_profile(), &default);
    }

    // ── AC-6 ─────────────────────────────────────────────────────────────────

    /// AC-6: When all four profile fields differ, `set_profile` dispatches
    /// exactly four setter calls in the deterministic order:
    /// `SetPrimalFeas` → `SetDualFeas` → `SetSimplexCap` → `SetIpmCap`.
    #[test]
    fn set_profile_full_change_uses_deterministic_order() {
        let mut solver = ProfiledSolver::new(RecordingMockSolver::new());

        let p = SolveProfile {
            primal_feasibility_tolerance: 1e-7,
            dual_feasibility_tolerance: 1e-7,
            simplex_iteration_limit: 50_000,
            ipm_iteration_limit: 5_000,
        };
        solver.set_profile(&p);

        let setter_calls: Vec<RecordedCall> =
            filter_profile_calls(solver.inner().recorded_calls())
                .into_iter()
                .cloned()
                .collect();

        assert_eq!(
            setter_calls,
            vec![
                RecordedCall::SetPrimalFeas(1e-7),
                RecordedCall::SetDualFeas(1e-7),
                RecordedCall::SetSimplexCap(50_000),
                RecordedCall::SetIpmCap(5_000),
            ],
            "setter calls must appear in deterministic order: primal, dual, simplex, ipm"
        );
        assert_eq!(solver.current_profile(), &p);
    }

    // ── AC-7 ─────────────────────────────────────────────────────────────────

    /// AC-7: `ProfiledSolver<S>` forwards `load_model`, `add_rows`,
    /// `set_row_bounds`, `set_col_bounds`, and `solve` transparently to the
    /// inner solver, once each and in call order.
    #[test]
    fn solver_interface_methods_forward_to_inner() {
        let mut solver = ProfiledSolver::new(RecordingMockSolver::new());

        let template = make_test_template();
        let rows = make_test_row_batch();

        solver.load_model(&template);
        solver.add_rows(&rows);
        solver.set_row_bounds(&[], &[], &[]);
        solver.set_col_bounds(&[], &[], &[]);
        let _ = solver.solve(None);

        let calls = solver.inner().recorded_calls();
        assert_eq!(
            calls,
            &[
                RecordedCall::LoadModel,
                RecordedCall::AddRows,
                RecordedCall::SetRowBounds,
                RecordedCall::SetColBounds,
                RecordedCall::Solve,
            ],
            "expected each forwarded call exactly once, in order, got: {calls:?}"
        );
        assert_eq!(solver.name(), "RecordingMock");
        assert_eq!(solver.solver_name_version(), "RecordingMockSolver 0.0.0");
    }
}