Skip to main content

cobre_solver/
trait_def.rs

//! The [`SolverInterface`] trait definition.
//!
//! This module defines the central abstraction through which optimization
//! algorithms interact with LP solvers.
5
6use crate::types::{Basis, RowBatch, SolutionView, SolverError, SolverStatistics, StageTemplate};
7
8/// Backend-agnostic interface for LP solver instances.
9///
10/// # Design
11///
12/// The trait is resolved as a **generic type parameter at compile time**
13/// (compile-time monomorphization for FFI-wrapping trait; see docs/adr/003-compile-time-solver.md),
14/// not as `dyn SolverInterface`. This monomorphization approach
15/// eliminates virtual dispatch overhead on the hot path, where tens of millions
16/// of LP solves occur during a single training run. The training loop is
17/// parameterized as `fn train<S: SolverInterface>(solver_factory: impl Fn() -> S, ...)`.
18///
19/// # Thread Safety
20///
21/// The trait requires `Send` but not `Sync`. `Send` allows solver instances to
22/// be transferred to worker threads during thread pool initialization. The
23/// absence of `Sync` prevents concurrent access, which matches the reality of
24/// C-library solver handles (`HiGHS`, CLP): they maintain mutable internal state
25/// (factorization workspace, working arrays) that is not thread-safe. Each
26/// worker thread owns exactly one solver instance for the duration of the
27/// training run, following the thread-local workspace pattern described in
28/// Solver Workspaces SS1.1.
29///
30/// # Mutability Convention
31///
32/// - Mutating methods (`load_model`, `add_rows`, `set_row_bounds`,
33///   `set_col_bounds`, `solve`, `solve_with_basis`, `reset`) take `&mut self`.
34/// - Methods that write to internal scratch buffers (`get_basis`) take `&mut self`.
35/// - Read-only query methods (`statistics`, `name`) take `&self`.
36///
37/// # Error Recovery Contract
38///
39/// When `solve` or `solve_with_basis` returns `Err`, the solver's
40/// internal state is unspecified. The **caller** is responsible for calling
41/// `reset()` before reusing the instance for another solve sequence. Failing to
42/// call `reset()` after an error may produce incorrect results or panics.
43///
44/// # Usage as a Generic Bound
45///
46/// ```rust
47/// use cobre_solver::{SolverInterface, SolutionView, SolverError};
48///
49/// fn run_solve<S: SolverInterface>(solver: &mut S) -> Result<SolutionView<'_>, SolverError> {
50///     solver.solve()
51/// }
52/// ```
53///
54/// See [Solver Interface Trait SS1](../../../cobre-docs/src/specs/architecture/solver-interface-trait.md)
55/// and [Solver Interface Trait SS5](../../../cobre-docs/src/specs/architecture/solver-interface-trait.md)
56/// for the dispatch mechanism rationale.
57pub trait SolverInterface: Send {
58    /// Bulk-loads a pre-assembled structural LP (first step of rebuild sequence).
59    ///
60    /// Replaces any previous model. Validates template is a valid CSC matrix
61    /// with `num_cols > 0` and `num_rows > 0` (panic on violation).
62    ///
63    /// See Solver Interface Trait SS2.1.
64    fn load_model(&mut self, template: &StageTemplate);
65
66    /// Appends constraint rows to the dynamic constraint region (step 2 of rebuild).
67    ///
68    /// Requires [`load_model`](Self::load_model) called first and `cuts` to have
69    /// valid CSR data with column indices in `[0, num_cols)` (panic on violation).
70    ///
71    /// See Solver Interface Trait SS2.2.
72    fn add_rows(&mut self, cuts: &RowBatch);
73
74    /// Updates row bounds (step 3 of rebuild; patching for scenario realization).
75    ///
76    /// `indices`, `lower`, and `upper` must have equal length, with all indices
77    /// referencing valid rows and bounds finite. For equality constraints, set
78    /// `lower[i] == upper[i]`. Panics if lengths differ or indices are out-of-bounds.
79    ///
80    /// See Solver Interface Trait SS2.3.
81    fn set_row_bounds(&mut self, indices: &[usize], lower: &[f64], upper: &[f64]);
82
83    /// Updates column bounds (per-scenario variable bound patching).
84    ///
85    /// `indices`, `lower`, and `upper` must have equal length, with all indices
86    /// referencing valid columns and bounds finite. Panics if lengths differ or
87    /// indices are out-of-bounds.
88    ///
89    /// See Solver Interface Trait SS2.3a.
90    fn set_col_bounds(&mut self, indices: &[usize], lower: &[f64], upper: &[f64]);
91
92    /// Solves the LP, returning a zero-copy view or terminal error after retry exhaustion.
93    ///
94    /// Hot-path method encapsulating internal retry logic. Requires [`Self::load_model`]
95    /// called first and scenario patches applied. On error, caller must call
96    /// [`Self::reset`] before reusing. The returned [`SolutionView`] borrows
97    /// solver-internal buffers and is valid until the next `&mut self` call. Call
98    /// [`SolutionView::to_owned`] when the solution must outlive the borrow.
99    ///
100    /// # Errors
101    ///
102    /// Returns `Err(SolverError)` when all internal retry attempts exhausted.
103    /// Possible variants: [`SolverError::Infeasible`], [`SolverError::Unbounded`],
104    /// [`SolverError::NumericalDifficulty`], [`SolverError::TimeLimitExceeded`],
105    /// [`SolverError::IterationLimit`], or [`SolverError::InternalError`].
106    ///
107    /// See Solver Interface Trait SS2.4.
108    fn solve(&mut self) -> Result<SolutionView<'_>, SolverError>;
109
110    /// Clears internal solver state for error recovery or LP structure change.
111    ///
112    /// Requires [`Self::load_model`] before next solve. Preserves `SolverStatistics`
113    /// counters; does not zero them.
114    ///
115    /// See Solver Interface Trait SS2.6.
116    fn reset(&mut self);
117
118    /// Writes solver-native `i32` status codes into a caller-owned [`Basis`] buffer.
119    ///
120    /// The caller pre-allocates a [`Basis`] with [`Basis::new`] and reuses it
121    /// across iterations, eliminating per-element enum translation overhead.
122    ///
123    /// The buffer is not resized by this method. The implementation writes into
124    /// the first `num_cols` entries of `out.col_status` and the first `num_rows`
125    /// entries of `out.row_status`. Panics if no model is loaded.
126    ///
127    /// See Solver Interface Trait SS2.7.
128    fn get_basis(&mut self, out: &mut Basis);
129
130    /// Injects a basis and solves, returning a zero-copy [`SolutionView`].
131    ///
132    /// Status codes in `basis` are injected directly without per-element enum
133    /// translation. On success the returned view borrows solver-internal buffers
134    /// and is valid until the next `&mut self` call. Call [`SolutionView::to_owned`]
135    /// when the solution must outlive the borrow.
136    ///
137    /// # Errors
138    ///
139    /// Same error contract as [`solve`](Self::solve).
140    ///
141    /// See Solver Interface Trait SS2.5.
142    fn solve_with_basis(&mut self, basis: &Basis) -> Result<SolutionView<'_>, SolverError>;
143
144    /// Returns accumulated solve metrics (snapshot of monotonically increasing counters).
145    ///
146    /// Statistics accumulate since construction; [`Self::reset`] does not zero them.
147    /// All fields non-negative.
148    ///
149    /// See Solver Interface Trait SS2.8.
150    fn statistics(&self) -> SolverStatistics;
151
152    /// Returns a static string identifying the solver backend (e.g., `"HiGHS"`).
153    ///
154    /// Used for logging, diagnostics, and checkpoint metadata.
155    ///
156    /// See Solver Interface Trait SS2.9.
157    fn name(&self) -> &'static str;
158}
159
#[cfg(test)]
mod tests {
    use super::SolverInterface;
    use crate::types::{Basis, RowBatch, SolutionView, SolverError, SolverStatistics, StageTemplate};

    /// Backend stub: every mutator is a no-op and every solve fails with
    /// [`SolverError::InternalError`]. It exists only to exercise the trait's
    /// shape at compile time and through trivial calls.
    struct NoopSolver;

    /// Builds the single error value that both `NoopSolver` solve methods return.
    fn noop_error() -> SolverError {
        SolverError::InternalError {
            message: "noop".to_string(),
            error_code: None,
        }
    }

    impl SolverInterface for NoopSolver {
        fn load_model(&mut self, _template: &StageTemplate) {}

        fn add_rows(&mut self, _cuts: &RowBatch) {}

        fn set_row_bounds(&mut self, _indices: &[usize], _lower: &[f64], _upper: &[f64]) {}

        fn set_col_bounds(&mut self, _indices: &[usize], _lower: &[f64], _upper: &[f64]) {}

        fn solve(&mut self) -> Result<SolutionView<'_>, SolverError> {
            Err(noop_error())
        }

        fn reset(&mut self) {}

        fn get_basis(&mut self, _out: &mut Basis) {}

        fn solve_with_basis(&mut self, _basis: &Basis) -> Result<SolutionView<'_>, SolverError> {
            Err(noop_error())
        }

        fn statistics(&self) -> SolverStatistics {
            SolverStatistics::default()
        }

        fn name(&self) -> &'static str {
            "Noop"
        }
    }

    /// Compiles only if `S` satisfies the trait as a generic bound
    /// (compile-time monomorphization).
    fn accepts_solver<S: SolverInterface>(_: &S) {}

    /// Compiles only if `T` is `Send`.
    fn assert_send<T: Send>() {}

    #[test]
    fn test_trait_compiles_as_generic_bound() {
        accepts_solver(&NoopSolver);
    }

    #[test]
    fn test_solver_interface_send_bound() {
        assert_send::<NoopSolver>();
    }

    #[test]
    fn test_noop_solver_name() {
        let name = NoopSolver.name();
        assert_eq!(name, "Noop");
        assert!(!name.is_empty());
    }

    #[test]
    fn test_noop_solver_statistics_initial() {
        let stats = NoopSolver.statistics();
        assert_eq!(stats.solve_count, 0);
        assert_eq!(stats.success_count, 0);
        assert_eq!(stats.failure_count, 0);
        assert_eq!(stats.total_iterations, 0);
        assert_eq!(stats.retry_count, 0);
        // The value comes straight from `Default`, so exact comparison with
        // zero is what the original test asserts as well.
        assert_eq!(stats.total_solve_time_seconds, 0.0);
    }

    #[test]
    fn test_noop_solver_get_basis_noop() {
        let mut solver = NoopSolver;
        let mut raw = Basis::new(3, 2);

        // Fill both status arrays with a sentinel; `NoopSolver::get_basis`
        // writes nothing, so the sentinel must survive the call.
        for code in raw.col_status.iter_mut() {
            *code = 99_i32;
        }
        for code in raw.row_status.iter_mut() {
            *code = 99_i32;
        }

        solver.get_basis(&mut raw);

        assert!(raw.col_status.iter().all(|&code| code == 99_i32));
        assert!(raw.row_status.iter().all(|&code| code == 99_i32));
    }

    #[test]
    fn test_noop_solver_solve_with_basis_returns_internal_error() {
        let mut solver = NoopSolver;
        let raw = Basis::new(0, 0);
        let result = solver.solve_with_basis(&raw);
        assert!(matches!(result, Err(SolverError::InternalError { .. })));
    }

    #[test]
    fn test_noop_solver_all_methods() {
        // One column, no rows, no non-zeros: just enough data to call each method.
        let template = StageTemplate {
            num_cols: 1,
            num_rows: 0,
            num_nz: 0,
            col_starts: vec![0_i32, 0],
            row_indices: vec![],
            values: vec![],
            col_lower: vec![0.0],
            col_upper: vec![1.0],
            objective: vec![1.0],
            row_lower: vec![],
            row_upper: vec![],
            n_state: 0,
            n_transfer: 0,
            n_dual_relevant: 0,
            n_hydro: 0,
            max_par_order: 0,
            col_scale: Vec::new(),
            row_scale: Vec::new(),
        };

        // Empty batch: zero rows, a single row-start entry.
        let batch = RowBatch {
            num_rows: 0,
            row_starts: vec![0_i32],
            col_indices: vec![],
            values: vec![],
            row_lower: vec![],
            row_upper: vec![],
        };

        let mut solver = NoopSolver;
        solver.load_model(&template);
        solver.add_rows(&batch);
        solver.set_row_bounds(&[], &[], &[]);
        solver.set_col_bounds(&[], &[], &[]);

        let result = solver.solve();
        assert!(matches!(result, Err(SolverError::InternalError { .. })));

        solver.reset();
    }
}