Skip to main content

cobre_solver/
baking.rs

1//! CSR-to-CSC template baking for reusable LP stage templates.
2//!
3//! Provides [`bake_rows_into_template`], which merges a CSC base template with a
4//! CSR row batch into a larger CSC template. The result can be loaded directly via
5//! [`crate::SolverInterface::load_model`] without a subsequent `add_rows` call.
6//!
7//! # Algorithm
8//!
//! The merge runs in three sequential passes:
//!
//! 1. **Count pass**: for each appended CSR non-zero `k`, increment a
//!    per-column scratch counter `cut_nz_per_col[col_indices[k]]`; a prefix sum
//!    of these counts gives each column's slot in flat per-column scratch buffers.
//! 2. **Scatter pass**: walk the CSR rows in ascending order and place each
//!    `(row, value)` pair into the next free position of its column's slot.
//! 3. **Emit pass**: walk columns `0..num_cols`; for each column `j`, emit the
//!    original base entries from `base.col_starts[j]..base.col_starts[j+1]` followed
//!    by all CSR entries whose column index equals `j`, in ascending CSR row order.
16//!
17//! The five intermediate buffers used by these passes are caller-owned and
18//! reused via [`BakingScratch`]: the caller constructs one `BakingScratch` and
19//! passes `&mut` to every call, so a steady-state bake performs no temporary
20//! heap allocation once the scratch capacities stabilize.
21//!
22//! # CSC Ordering Convention
23//!
24//! Within each column the original base rows appear first (in their original CSC
25//! order), followed by the appended rows in ascending CSR row order. `HiGHS` does not
26//! require sorted per-column row indices, but this convention is maintained for
27//! reproducibility and ease of debugging.
28//!
29//! See [Solver Abstraction SS11.1](../../../cobre-docs/src/specs/architecture/solver-abstraction.md).
30
31use crate::types::{RowBatch, StageTemplate};
32
/// Reusable, caller-owned workspace for [`bake_rows_into_template`].
///
/// Every bake needs five temporary buffers whose lengths depend only on the
/// inputs of that call. Keeping a single `BakingScratch` alive and handing
/// `&mut` to each bake lets those buffers keep their heap storage between
/// calls. Each one is `clear()`-ed and regrown at the start of a bake and is
/// never shrunk (no `shrink_to_fit`). Once the largest template/batch
/// combination has been processed, further bakes allocate no temporaries.
///
/// The struct is deliberately opaque: every field is private. Its buffers,
/// listed with element type and per-call length, are:
///
/// - `cut_nz_per_col: Vec<u32>` — `num_cols` entries; how many appended CSR
///   non-zeros land in each column (count pass).
/// - `col_list_start: Vec<u32>` — `num_cols + 1` entries; prefix-sum offsets of
///   each column's slot inside the flat grouped buffers.
/// - `col_list_row: Vec<i32>` — `rows_nnz` entries; row indices grouped by
///   column.
/// - `col_list_val: Vec<f64>` — `rows_nnz` entries; values grouped by column.
/// - `write_cursor: Vec<u32>` — `num_cols` entries; next free offset within
///   each column's slot.
#[derive(Debug, Default)]
pub struct BakingScratch {
    /// Number of appended CSR non-zeros per column, filled by the count pass.
    cut_nz_per_col: Vec<u32>,
    /// Prefix-sum start offset of each column inside the flat grouped buffers.
    col_list_start: Vec<u32>,
    /// Row indices of the appended entries, grouped by column.
    col_list_row: Vec<i32>,
    /// Values of the appended entries, grouped by column.
    col_list_val: Vec<f64>,
    /// How many entries have already been written into each column's slot.
    write_cursor: Vec<u32>,
}

impl BakingScratch {
    /// Create a scratch that owns no storage yet.
    ///
    /// Buffers are allocated lazily by the first bake that uses them.
    #[must_use]
    pub fn new() -> Self {
        Self {
            cut_nz_per_col: Vec::new(),
            col_list_start: Vec::new(),
            col_list_row: Vec::new(),
            col_list_val: Vec::new(),
            write_cursor: Vec::new(),
        }
    }
}
76
77/// Merge a CSC base template with a CSR row batch into an output CSC template.
78///
79/// After return, `out` is a valid [`StageTemplate`] in CSC form with
80/// `out.num_rows == base.num_rows + rows.num_rows`.
81///
82/// # Buffer Reuse
83///
84/// `out` is cleared and refilled on every call without calling `shrink_to_fit`.
85/// Passing the same buffer on every iteration reuses allocations with zero
86/// additional allocation at steady state once capacity stabilizes.
87///
88/// The five intermediate buffers are owned by the caller-supplied `scratch`
89/// (see [`BakingScratch`]) and follow the identical `clear()`-then-regrow
90/// discipline: pass the same `BakingScratch` across calls to reuse them.
91///
92/// # Preconditions
93///
94/// Checked via `debug_assert!` in debug builds:
95/// - `base.col_starts.len() == base.num_cols + 1` and sentinel `== base.num_nz`
96/// - `base.row_indices.len() == base.values.len() == base.num_nz`
97/// - `base` column/row bound arrays and objective each have the correct length
98/// - `rows.row_starts[0] == 0` (when `rows.num_rows > 0`)
99/// - Every `col_indices[k] < base.num_cols`
100///
101/// # Panics
102///
103/// Panics if `base.num_nz + rows_nnz` exceeds `i32::MAX` (`HiGHS` API limit).
104#[allow(clippy::too_many_lines)] // complex data-structure merge; extracting sub-functions would obscure the algorithm
105pub fn bake_rows_into_template(
106    base: &StageTemplate,
107    rows: &RowBatch,
108    out: &mut StageTemplate,
109    scratch: &mut BakingScratch,
110) {
111    // Precondition guards (debug builds only).
112    #[allow(clippy::cast_sign_loss)]
113    {
114        debug_assert_eq!(
115            base.col_starts.len(),
116            base.num_cols + 1,
117            "base.col_starts.len()={} but num_cols+1={}",
118            base.col_starts.len(),
119            base.num_cols + 1
120        );
121        debug_assert_eq!(
122            base.col_starts.last().copied().unwrap_or(0) as usize,
123            base.num_nz,
124            "base.col_starts[num_cols] != base.num_nz"
125        );
126        debug_assert_eq!(base.row_indices.len(), base.num_nz);
127        debug_assert_eq!(base.values.len(), base.num_nz);
128        debug_assert_eq!(base.col_lower.len(), base.num_cols);
129        debug_assert_eq!(base.col_upper.len(), base.num_cols);
130        debug_assert_eq!(base.objective.len(), base.num_cols);
131        debug_assert_eq!(base.row_lower.len(), base.num_rows);
132        debug_assert_eq!(base.row_upper.len(), base.num_rows);
133        debug_assert!(
134            base.col_scale.is_empty() || base.col_scale.len() == base.num_cols,
135            "base.col_scale must be empty or length num_cols"
136        );
137        debug_assert!(
138            base.row_scale.is_empty() || base.row_scale.len() == base.num_rows,
139            "base.row_scale must be empty or length num_rows"
140        );
141
142        if rows.num_rows > 0 {
143            debug_assert_eq!(
144                rows.row_starts.len(),
145                rows.num_rows + 1,
146                "rows.row_starts.len()={} but num_rows+1={}",
147                rows.row_starts.len(),
148                rows.num_rows + 1
149            );
150            debug_assert_eq!(
151                rows.row_starts[0], 0,
152                "RowBatch invariant: row_starts[0] must be 0"
153            );
154            debug_assert_eq!(rows.row_lower.len(), rows.num_rows);
155            debug_assert_eq!(rows.row_upper.len(), rows.num_rows);
156
157            let rows_nnz = rows.row_starts[rows.num_rows] as usize;
158            debug_assert_eq!(rows.col_indices.len(), rows_nnz);
159            debug_assert_eq!(rows.values.len(), rows_nnz);
160
161            #[cfg(debug_assertions)]
162            for &col in &rows.col_indices {
163                debug_assert!(
164                    (col as usize) < base.num_cols,
165                    "col_indices[k]={col} >= base.num_cols={}",
166                    base.num_cols
167                );
168            }
169        }
170    }
171
172    // Compute total nnz and validate it fits in i32.
173    #[allow(clippy::cast_sign_loss)]
174    let rows_nnz = if rows.num_rows > 0 {
175        rows.row_starts[rows.num_rows] as usize
176    } else {
177        0
178    };
179    let total_nnz = base.num_nz + rows_nnz;
180
181    #[allow(clippy::expect_used)]
182    let total_nnz_i32 = i32::try_from(total_nnz).expect("total nnz exceeds i32::MAX");
183
184    let num_cols = base.num_cols;
185    let num_rows = base.num_rows + rows.num_rows;
186
187    // Pass 1: count CSR row contributions per column.
188    // clear() then resize(..., 0) guarantees all num_cols entries start at 0,
189    // identical to the old `vec![0u32; num_cols]`.
190    scratch.cut_nz_per_col.clear();
191    scratch.cut_nz_per_col.resize(num_cols, 0u32);
192    #[allow(clippy::cast_sign_loss)]
193    for &col in &rows.col_indices {
194        scratch.cut_nz_per_col[col as usize] += 1;
195    }
196
197    // Clear buffers (no shrink_to_fit — preserve capacity).
198    out.col_starts.clear();
199    out.row_indices.clear();
200    out.values.clear();
201    out.col_lower.clear();
202    out.col_upper.clear();
203    out.objective.clear();
204    out.col_scale.clear();
205    out.row_lower.clear();
206    out.row_upper.clear();
207    out.row_scale.clear();
208
209    // Write scalar fields.
210    out.num_cols = num_cols;
211    out.num_rows = num_rows;
212    out.num_nz = total_nnz;
213    out.n_state = base.n_state;
214    out.n_transfer = base.n_transfer;
215    out.n_dual_relevant = base.n_dual_relevant;
216    out.n_hydro = base.n_hydro;
217    out.max_par_order = base.max_par_order;
218
219    // Copy column-bound and objective arrays from base.
220    out.col_lower.extend_from_slice(&base.col_lower);
221    out.col_upper.extend_from_slice(&base.col_upper);
222    out.objective.extend_from_slice(&base.objective);
223    out.col_scale.extend_from_slice(&base.col_scale);
224
225    // Populate row_lower and row_upper.
226    out.row_lower.extend_from_slice(&base.row_lower);
227    out.row_lower.extend_from_slice(&rows.row_lower);
228    out.row_upper.extend_from_slice(&base.row_upper);
229    out.row_upper.extend_from_slice(&rows.row_upper);
230
231    // Populate row_scale: copy base (if non-empty) and append 1.0 for new rows.
232    // StageTemplate invariant (types.rs): when non-empty, row_scale.len() must
233    // equal num_rows. When the base has scaling and cuts are appended, extend
234    // to base+rows; when the base has none but rows are appended, materialise
235    // the full scale vector as 1.0 (base rows inherit the "no-op" scale).
236    if !base.row_scale.is_empty() {
237        out.row_scale.extend_from_slice(&base.row_scale);
238        out.row_scale
239            .resize(out.row_scale.len() + rows.num_rows, 1.0_f64);
240    } else if rows.num_rows > 0 {
241        out.row_scale.resize(base.num_rows + rows.num_rows, 1.0_f64);
242    }
243
244    // Pass 2: build col_starts, row_indices, values in column order.
245    // Compute column start offsets (prefix sum of cut_nz_per_col).
246    scratch.col_list_start.clear();
247    scratch.col_list_start.reserve(num_cols + 1);
248    let mut running = 0u32;
249    for &count in &scratch.cut_nz_per_col {
250        scratch.col_list_start.push(running);
251        running += count;
252    }
253    scratch.col_list_start.push(running);
254
255    // Flat scratch buffers for (row_index, value) pairs grouped by column.
256    // clear() then resize matches the old `vec![0; rows_nnz]` / `vec![0; num_cols]`
257    // exactly; cut_nz_per_col entries above are fully overwritten before read.
258    scratch.col_list_row.clear();
259    scratch.col_list_row.resize(rows_nnz, 0i32);
260    scratch.col_list_val.clear();
261    scratch.col_list_val.resize(rows_nnz, 0.0f64);
262    scratch.write_cursor.clear();
263    scratch.write_cursor.resize(num_cols, 0u32);
264
265    // Fill scratch buffers by scanning CSR rows in ascending order.
266    #[allow(clippy::cast_sign_loss)]
267    for r in 0..rows.num_rows {
268        let start = rows.row_starts[r] as usize;
269        let end = rows.row_starts[r + 1] as usize;
270        #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
271        let row_i32 = (base.num_rows + r) as i32;
272        for k in start..end {
273            let j = rows.col_indices[k] as usize;
274            let pos = (scratch.col_list_start[j] + scratch.write_cursor[j]) as usize;
275            scratch.col_list_row[pos] = row_i32;
276            scratch.col_list_val[pos] = rows.values[k];
277            scratch.write_cursor[j] += 1;
278        }
279    }
280
281    // Emit col_starts, row_indices, values in column order.
282    let mut nz_cursor: i32 = 0;
283    #[allow(clippy::cast_sign_loss)]
284    for j in 0..num_cols {
285        out.col_starts.push(nz_cursor);
286
287        let base_start = base.col_starts[j] as usize;
288        let base_end = base.col_starts[j + 1] as usize;
289        out.row_indices
290            .extend_from_slice(&base.row_indices[base_start..base_end]);
291        out.values
292            .extend_from_slice(&base.values[base_start..base_end]);
293
294        let list_start = scratch.col_list_start[j] as usize;
295        let list_end = scratch.col_list_start[j + 1] as usize;
296        out.row_indices
297            .extend_from_slice(&scratch.col_list_row[list_start..list_end]);
298        out.values
299            .extend_from_slice(&scratch.col_list_val[list_start..list_end]);
300
301        let col_len = (base_end - base_start) + (list_end - list_start);
302        #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
303        {
304            nz_cursor += col_len as i32;
305        }
306    }
307    out.col_starts.push(total_nnz_i32);
308}
309
310#[cfg(test)]
311mod tests {
312    use super::*;
313    use crate::types::{RowBatch, SolverStatistics, StageTemplate};
314
315    /// Builds the canonical 3-col, 2-row fixture from `types.rs::make_fixture_stage_template`.
316    ///
317    /// ```
318    /// col_starts = [0, 2, 2, 3]
319    /// row_indices = [0, 1, 1]
320    /// values = [1.0, 2.0, 1.0]
321    /// row_lower = [6.0, 14.0]
322    /// row_upper = [6.0, 14.0]
323    /// ```
324    fn make_fixture_stage_template() -> StageTemplate {
325        StageTemplate {
326            num_cols: 3,
327            num_rows: 2,
328            num_nz: 3,
329            col_starts: vec![0_i32, 2, 2, 3],
330            row_indices: vec![0_i32, 1, 1],
331            values: vec![1.0, 2.0, 1.0],
332            col_lower: vec![0.0, 0.0, 0.0],
333            col_upper: vec![10.0, f64::INFINITY, 8.0],
334            objective: vec![0.0, 1.0, 50.0],
335            row_lower: vec![6.0, 14.0],
336            row_upper: vec![6.0, 14.0],
337            n_state: 1,
338            n_transfer: 0,
339            n_dual_relevant: 1,
340            n_hydro: 1,
341            max_par_order: 0,
342            col_scale: Vec::new(),
343            row_scale: Vec::new(),
344        }
345    }
346
347    /// Builds an empty [`RowBatch`] with zero rows (but valid `row_starts` sentinel).
348    fn make_empty_row_batch() -> RowBatch {
349        RowBatch {
350            num_rows: 0,
351            row_starts: vec![0_i32],
352            col_indices: vec![],
353            values: vec![],
354            row_lower: vec![],
355            row_upper: vec![],
356        }
357    }
358
359    // -----------------------------------------------------------------------
360    // Test 1: empty rows → structural copy of base
361    // -----------------------------------------------------------------------
362
363    #[test]
364    fn test_bake_empty_rows_copies_base() {
365        let base = make_fixture_stage_template();
366        let rows = make_empty_row_batch();
367        let mut out = StageTemplate::empty();
368
369        bake_rows_into_template(&base, &rows, &mut out, &mut BakingScratch::default());
370
371        assert_eq!(out.num_cols, base.num_cols);
372        assert_eq!(out.num_rows, base.num_rows);
373        assert_eq!(out.num_nz, base.num_nz);
374        assert_eq!(out.col_starts, base.col_starts);
375        assert_eq!(out.row_indices, base.row_indices);
376        assert_eq!(out.values, base.values);
377        assert_eq!(out.col_lower, base.col_lower);
378        assert_eq!(out.col_upper, base.col_upper);
379        assert_eq!(out.objective, base.objective);
380        assert_eq!(out.row_lower, base.row_lower);
381        assert_eq!(out.row_upper, base.row_upper);
382        assert_eq!(out.n_state, base.n_state);
383        assert_eq!(out.n_transfer, base.n_transfer);
384        assert_eq!(out.n_dual_relevant, base.n_dual_relevant);
385        assert_eq!(out.n_hydro, base.n_hydro);
386        assert_eq!(out.max_par_order, base.max_par_order);
387        // empty row_scale stays empty
388        assert!(out.row_scale.is_empty());
389    }
390
391    // -----------------------------------------------------------------------
392    // Test 2: single appended row — exact CSC column layout
393    // -----------------------------------------------------------------------
394
395    #[test]
396    fn test_bake_single_row_appends_correct_column_entries() {
397        // Fixture: 3-col, 2-row base as described.
398        // RowBatch: one row touching cols 0 and 2.
399        let base = make_fixture_stage_template();
400        let rows = RowBatch {
401            num_rows: 1,
402            row_starts: vec![0_i32, 2],
403            col_indices: vec![0_i32, 2],
404            values: vec![-1.5, 1.0],
405            row_lower: vec![10.0],
406            row_upper: vec![f64::INFINITY],
407        };
408        let mut out = StageTemplate::empty();
409
410        bake_rows_into_template(&base, &rows, &mut out, &mut BakingScratch::default());
411
412        assert_eq!(out.num_rows, 3);
413        assert_eq!(out.num_nz, 5);
414
415        // Corrected col_starts: col 1 has 0 entries.
416        assert_eq!(out.col_starts, vec![0_i32, 3, 3, 5]);
417
418        // Column 0: base rows [0,1] then cut row [2]
419        assert_eq!(&out.row_indices[0..3], &[0_i32, 1, 2]);
420        assert_eq!(&out.values[0..3], &[1.0_f64, 2.0, -1.5]);
421
422        // Column 1: no entries (col_starts[1]==3, col_starts[2]==3)
423
424        // Column 2: base row [1] then cut row [2]  (positions 3..5)
425        assert_eq!(&out.row_indices[3..5], &[1_i32, 2]);
426        assert_eq!(&out.values[3..5], &[1.0_f64, 1.0]);
427
428        // Row bounds
429        assert_eq!(out.row_lower, vec![6.0_f64, 14.0, 10.0]);
430        assert!(out.row_upper[2].is_infinite() && out.row_upper[2] > 0.0);
431    }
432
433    // -----------------------------------------------------------------------
434    // Test 3: non-empty row_scale — appended cut rows default to 1.0
435    // -----------------------------------------------------------------------
436
437    #[test]
438    fn test_bake_preserves_row_scale_and_defaults_cut_rows_to_one() {
439        let mut base = make_fixture_stage_template();
440        base.row_scale = vec![1.0, 2.0];
441
442        let rows = RowBatch {
443            num_rows: 1,
444            row_starts: vec![0_i32, 1],
445            col_indices: vec![0_i32],
446            values: vec![-1.0],
447            row_lower: vec![5.0],
448            row_upper: vec![f64::INFINITY],
449        };
450        let mut out = StageTemplate::empty();
451
452        bake_rows_into_template(&base, &rows, &mut out, &mut BakingScratch::default());
453
454        assert_eq!(out.row_scale.len(), 3);
455        assert_eq!(out.row_scale[0], 1.0);
456        assert_eq!(out.row_scale[1], 2.0);
457        assert_eq!(out.row_scale[2], 1.0); // cut row implicit scale
458    }
459
460    // -----------------------------------------------------------------------
461    // Test 4: empty row_scale + zero rows → out.row_scale stays empty
462    // -----------------------------------------------------------------------
463
464    #[test]
465    fn test_bake_preserves_empty_row_scale_when_no_rows() {
466        let base = make_fixture_stage_template(); // row_scale is empty
467        let rows = make_empty_row_batch();
468        let mut out = StageTemplate::empty();
469
470        bake_rows_into_template(&base, &rows, &mut out, &mut BakingScratch::default());
471
472        assert!(out.row_scale.is_empty());
473        assert_eq!(out.num_rows, base.num_rows);
474    }
475
476    // -----------------------------------------------------------------------
477    // Test 5: reuse out buffer — capacity must not decrease
478    // -----------------------------------------------------------------------
479
480    #[test]
481    fn test_bake_reuses_out_buffer_capacity() {
482        // First call: 5 rows, 10 nz (simulated by a larger base).
483        let big_base = StageTemplate {
484            num_cols: 2,
485            num_rows: 5,
486            num_nz: 10,
487            col_starts: vec![0_i32, 5, 10],
488            row_indices: vec![0_i32, 1, 2, 3, 4, 0, 1, 2, 3, 4],
489            values: vec![1.0; 10],
490            col_lower: vec![0.0, 0.0],
491            col_upper: vec![f64::INFINITY, f64::INFINITY],
492            objective: vec![1.0, 1.0],
493            row_lower: vec![0.0; 5],
494            row_upper: vec![f64::INFINITY; 5],
495            n_state: 0,
496            n_transfer: 0,
497            n_dual_relevant: 0,
498            n_hydro: 0,
499            max_par_order: 0,
500            col_scale: Vec::new(),
501            row_scale: Vec::new(),
502        };
503        let empty_rows = make_empty_row_batch();
504        let mut out = StageTemplate::empty();
505
506        bake_rows_into_template(
507            &big_base,
508            &empty_rows,
509            &mut out,
510            &mut BakingScratch::default(),
511        );
512
513        // Capture capacities after the first (larger) call.
514        let cap_col_starts = out.col_starts.capacity();
515        let cap_row_indices = out.row_indices.capacity();
516        let cap_values = out.values.capacity();
517        let cap_row_lower = out.row_lower.capacity();
518        let cap_row_upper = out.row_upper.capacity();
519
520        // Second call: 4 rows, 8 nz (smaller — must not reallocate downward).
521        let small_base = StageTemplate {
522            num_cols: 2,
523            num_rows: 4,
524            num_nz: 8,
525            col_starts: vec![0_i32, 4, 8],
526            row_indices: vec![0_i32, 1, 2, 3, 0, 1, 2, 3],
527            values: vec![1.0; 8],
528            col_lower: vec![0.0, 0.0],
529            col_upper: vec![f64::INFINITY, f64::INFINITY],
530            objective: vec![1.0, 1.0],
531            row_lower: vec![0.0; 4],
532            row_upper: vec![f64::INFINITY; 4],
533            n_state: 0,
534            n_transfer: 0,
535            n_dual_relevant: 0,
536            n_hydro: 0,
537            max_par_order: 0,
538            col_scale: Vec::new(),
539            row_scale: Vec::new(),
540        };
541
542        bake_rows_into_template(
543            &small_base,
544            &empty_rows,
545            &mut out,
546            &mut BakingScratch::default(),
547        );
548
549        assert_eq!(out.num_rows, 4);
550        assert_eq!(out.num_nz, 8);
551
552        // Capacities must not have decreased.
553        assert!(out.col_starts.capacity() >= cap_col_starts);
554        assert!(out.row_indices.capacity() >= cap_row_indices);
555        assert!(out.values.capacity() >= cap_values);
556        assert!(out.row_lower.capacity() >= cap_row_lower);
557        assert!(out.row_upper.capacity() >= cap_row_upper);
558    }
559
560    // -----------------------------------------------------------------------
561    // Test 6: determinism — two calls produce identical output
562    // -----------------------------------------------------------------------
563
564    #[test]
565    fn test_bake_determinism() {
566        let base = make_fixture_stage_template();
567        let rows = RowBatch {
568            num_rows: 2,
569            row_starts: vec![0_i32, 2, 3],
570            col_indices: vec![0_i32, 2, 1],
571            values: vec![-1.0, 0.5, 3.0],
572            row_lower: vec![8.0, 12.0],
573            row_upper: vec![f64::INFINITY, f64::INFINITY],
574        };
575
576        let mut out1 = StageTemplate::empty();
577        let mut out2 = StageTemplate::empty();
578
579        bake_rows_into_template(&base, &rows, &mut out1, &mut BakingScratch::default());
580        bake_rows_into_template(&base, &rows, &mut out2, &mut BakingScratch::default());
581
582        assert_eq!(out1.col_starts, out2.col_starts);
583        assert_eq!(out1.row_indices, out2.row_indices);
584        assert_eq!(out1.values, out2.values);
585        assert_eq!(out1.row_lower, out2.row_lower);
586        assert_eq!(out1.row_upper, out2.row_upper);
587    }
588
589    // -----------------------------------------------------------------------
590    // Test 7: multi-column distribution — 4-column base, 3 CSR rows
591    // -----------------------------------------------------------------------
592
593    #[test]
594    fn test_bake_multi_column_distribution() {
595        // 4-column, 3-row base:
596        //   col 0: rows [0,1]         values [1.0, 2.0]
597        //   col 1: row  [2]           value  [3.0]
598        //   col 2: rows [0,1,2]       values [4.0, 5.0, 6.0]
599        //   col 3: (empty)
600        // col_starts = [0, 2, 3, 6, 6]
601        let base = StageTemplate {
602            num_cols: 4,
603            num_rows: 3,
604            num_nz: 6,
605            col_starts: vec![0_i32, 2, 3, 6, 6],
606            row_indices: vec![0_i32, 1, 2, 0, 1, 2],
607            values: vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
608            col_lower: vec![0.0; 4],
609            col_upper: vec![f64::INFINITY; 4],
610            objective: vec![0.0; 4],
611            row_lower: vec![0.0; 3],
612            row_upper: vec![f64::INFINITY; 3],
613            n_state: 2,
614            n_transfer: 1,
615            n_dual_relevant: 2,
616            n_hydro: 2,
617            max_par_order: 1,
618            col_scale: Vec::new(),
619            row_scale: Vec::new(),
620        };
621
622        // 3 CSR rows:
623        //   row 3 (cut 0): cols [0, 3]
624        //   row 4 (cut 1): cols [1, 2]
625        //   row 5 (cut 2): cols [0, 2, 3]
626        let rows = RowBatch {
627            num_rows: 3,
628            row_starts: vec![0_i32, 2, 4, 7],
629            col_indices: vec![0_i32, 3, 1, 2, 0, 2, 3],
630            values: vec![-1.0, 1.0, -2.0, 2.0, -3.0, 3.0, -4.0],
631            row_lower: vec![10.0, 20.0, 30.0],
632            row_upper: vec![f64::INFINITY; 3],
633        };
634
635        let mut out = StageTemplate::empty();
636        bake_rows_into_template(&base, &rows, &mut out, &mut BakingScratch::default());
637
638        assert_eq!(out.num_rows, 6);
639        // col 0: 2 base + 2 cut (rows 3, 5) = 4
640        // col 1: 1 base + 1 cut (row 4) = 2
641        // col 2: 3 base + 2 cut (rows 4, 5) = 5
642        // col 3: 0 base + 2 cut (rows 3, 5) = 2
643        // total = 13
644        assert_eq!(out.num_nz, 13);
645        assert_eq!(out.col_starts, vec![0_i32, 4, 6, 11, 13]);
646
647        // Column 0 entries: base rows [0,1], then cut rows [3,5]
648        assert_eq!(&out.row_indices[0..4], &[0_i32, 1, 3, 5]);
649        assert_eq!(&out.values[0..4], &[1.0_f64, 2.0, -1.0, -3.0]);
650
651        // Column 1 entries: base row [2], then cut row [4]
652        assert_eq!(&out.row_indices[4..6], &[2_i32, 4]);
653        assert_eq!(&out.values[4..6], &[3.0_f64, -2.0]);
654
655        // Column 2 entries: base rows [0,1,2], then cut rows [4,5]
656        assert_eq!(&out.row_indices[6..11], &[0_i32, 1, 2, 4, 5]);
657        assert_eq!(&out.values[6..11], &[4.0_f64, 5.0, 6.0, 2.0, 3.0]);
658
659        // Column 3 entries: base (empty), then cut rows [3,5]
660        assert_eq!(&out.row_indices[11..13], &[3_i32, 5]);
661        assert_eq!(&out.values[11..13], &[1.0_f64, -4.0]);
662
663        // Row bounds
664        assert_eq!(&out.row_lower[3..6], &[10.0_f64, 20.0, 30.0]);
665    }
666
667    // -----------------------------------------------------------------------
668    // Test 8: MockSolver records num_rows from load_model
669    // -----------------------------------------------------------------------
670
671    /// A minimal [`crate::SolverInterface`] implementation that records the
672    /// `num_rows` value from the most recent `load_model` call.
673    struct MockSolver {
674        last_loaded_num_rows: usize,
675        stats: SolverStatistics,
676    }
677
678    impl MockSolver {
679        fn new() -> Self {
680            Self {
681                last_loaded_num_rows: 0,
682                stats: SolverStatistics::default(),
683            }
684        }
685    }
686
687    impl crate::SolverInterface for MockSolver {
688        type Profile = crate::profile::MockProfile;
689
690        fn apply_profile(&mut self, _profile: &crate::profile::MockProfile) {}
691
692        fn load_model(&mut self, template: &StageTemplate) {
693            self.last_loaded_num_rows = template.num_rows;
694            self.stats.load_model_count += 1;
695        }
696
697        fn add_rows(&mut self, _rows: &RowBatch) {}
698
699        fn set_row_bounds(&mut self, _indices: &[usize], _lower: &[f64], _upper: &[f64]) {}
700
701        fn set_col_bounds(&mut self, _indices: &[usize], _lower: &[f64], _upper: &[f64]) {}
702
703        fn solve(
704            &mut self,
705            _basis: Option<&crate::types::Basis>,
706        ) -> Result<crate::types::SolutionView<'_>, crate::types::SolverError> {
707            Err(crate::types::SolverError::InternalError {
708                message: "mock".to_string(),
709                error_code: None,
710            })
711        }
712
713        fn get_basis(&mut self, _out: &mut crate::types::Basis) {}
714
715        fn statistics(&self) -> SolverStatistics {
716            self.stats.clone()
717        }
718
719        fn statistics_into(&self, out: &mut SolverStatistics) {
720            out.copy_from(&self.stats);
721        }
722
723        fn name(&self) -> &'static str {
724            "Mock"
725        }
726
727        fn solver_name_version(&self) -> String {
728            "MockSolver 0.0.0".to_string()
729        }
730    }
731
732    #[test]
733    fn test_bake_load_model_row_count() {
734        use crate::SolverInterface;
735
736        let base = make_fixture_stage_template();
737        let rows = RowBatch {
738            num_rows: 3,
739            row_starts: vec![0_i32, 1, 2, 3],
740            col_indices: vec![0_i32, 1, 2],
741            values: vec![-1.0, -1.0, -1.0],
742            row_lower: vec![5.0, 6.0, 7.0],
743            row_upper: vec![f64::INFINITY; 3],
744        };
745
746        let mut out = StageTemplate::empty();
747        bake_rows_into_template(&base, &rows, &mut out, &mut BakingScratch::default());
748
749        let expected_rows = base.num_rows + rows.num_rows; // 2 + 3 = 5
750
751        let mut solver = MockSolver::new();
752        let before = solver.statistics().load_model_count;
753        solver.load_model(&out);
754        let after = solver.statistics().load_model_count;
755
756        assert_eq!(after - before, 1);
757        assert_eq!(solver.last_loaded_num_rows, expected_rows);
758    }
759
760    // -----------------------------------------------------------------------
761    // Test (extra): empty row_scale + non-zero rows → appended 1.0 entries only
762    // -----------------------------------------------------------------------
763
764    #[test]
765    fn test_bake_empty_base_row_scale_with_cut_rows_appends_ones() {
766        let base = make_fixture_stage_template(); // row_scale is empty
767        let rows = RowBatch {
768            num_rows: 2,
769            row_starts: vec![0_i32, 1, 2],
770            col_indices: vec![0_i32, 0],
771            values: vec![-1.0, -2.0],
772            row_lower: vec![5.0, 6.0],
773            row_upper: vec![f64::INFINITY; 2],
774        };
775        let mut out = StageTemplate::empty();
776
777        bake_rows_into_template(&base, &rows, &mut out, &mut BakingScratch::default());
778
779        // StageTemplate invariant: when non-empty, row_scale.len() == num_rows.
780        // base.row_scale was empty but rows.num_rows == 2, so the baked template
781        // materialises a full row_scale of 1.0 (base.num_rows + rows.num_rows == 4).
782        assert_eq!(out.row_scale.len(), base.num_rows + rows.num_rows);
783        assert!(out.row_scale.iter().all(|&s| s == 1.0));
784    }
785
786    // -----------------------------------------------------------------------
787    // Test: i32::MAX overflow panics
788    //
789    // Reaching the `i32::try_from` check without OOM requires lying about
790    // `num_nz` while keeping the backing Vecs empty. In debug builds,
791    // `debug_assert!(base.row_indices.len() == base.num_nz)` fires before the
792    // overflow check, so the test is restricted to `#[cfg(not(debug_assertions))]`
793    // (i.e., `cargo test --release`).
794    // -----------------------------------------------------------------------
795
796    /// Verifies that `bake_rows_into_template` panics with the expected message
797    /// when `base.num_nz + rows_nnz` exceeds `i32::MAX`.
798    ///
799    /// Skipped in debug builds because `debug_assert!` on `base.row_indices.len()`
800    /// fires before the overflow guard when `num_nz` is fabricated. The
801    /// `i32::try_from` path exists in both builds; run `cargo test --release` to
802    /// exercise it directly.
803    #[test]
804    #[cfg(not(debug_assertions))]
805    #[should_panic(expected = "total nnz exceeds i32::MAX")]
806    fn test_bake_panics_on_nnz_overflow() {
807        // base: zero columns, zero actual non-zeros, but num_nz = i32::MAX.
808        // In release mode, debug_asserts are disabled so the i32::try_from guard
809        // is reached before any length check. rows contributes 1 extra non-zero
810        // (rows_nnz = 1), making total_nnz = i32::MAX + 1 which overflows i32.
811        let large_num_nz = usize::try_from(i32::MAX).unwrap(); // 2_147_483_647
812        let base = StageTemplate {
813            num_cols: 0,
814            num_rows: 0,
815            num_nz: large_num_nz,
816            col_starts: vec![0_i32], // len = num_cols + 1 = 1
817            row_indices: vec![],     // empty — debug_asserts disabled in release
818            values: vec![],
819            col_lower: vec![],
820            col_upper: vec![],
821            objective: vec![],
822            row_lower: vec![],
823            row_upper: vec![],
824            n_state: 0,
825            n_transfer: 0,
826            n_dual_relevant: 0,
827            n_hydro: 0,
828            max_par_order: 0,
829            col_scale: Vec::new(),
830            row_scale: Vec::new(),
831        };
832        // rows contributes 1 non-zero, tipping base.num_nz + 1 > i32::MAX.
833        // col_indices = [0] would be out-of-range for num_cols == 0, but the
834        // corresponding debug_assert is also disabled in release mode; the
835        // i32::try_from check fires first because total_nnz is computed before
836        // any further use of col_indices.
837        let rows = RowBatch {
838            num_rows: 1,
839            row_starts: vec![0_i32, 1],
840            col_indices: vec![0_i32],
841            values: vec![1.0],
842            row_lower: vec![0.0],
843            row_upper: vec![f64::INFINITY],
844        };
845        let mut out = StageTemplate::empty();
846        bake_rows_into_template(&base, &rows, &mut out, &mut BakingScratch::default());
847    }
848
849    // -----------------------------------------------------------------------
850    // Test: reusing one scratch across two bakes is bit-identical to a fresh
851    // scratch, and the scratch buffers never realloc downward on reuse.
852    // -----------------------------------------------------------------------
853
854    #[test]
855    fn bake_twice_same_scratch_is_bit_identical() {
856        // Test 7 multi-column fixture: 4-col base, 3 CSR rows.
857        let base = StageTemplate {
858            num_cols: 4,
859            num_rows: 3,
860            num_nz: 6,
861            col_starts: vec![0_i32, 2, 3, 6, 6],
862            row_indices: vec![0_i32, 1, 2, 0, 1, 2],
863            values: vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
864            col_lower: vec![0.0; 4],
865            col_upper: vec![f64::INFINITY; 4],
866            objective: vec![0.0; 4],
867            row_lower: vec![0.0; 3],
868            row_upper: vec![f64::INFINITY; 3],
869            n_state: 2,
870            n_transfer: 1,
871            n_dual_relevant: 2,
872            n_hydro: 2,
873            max_par_order: 1,
874            col_scale: Vec::new(),
875            row_scale: Vec::new(),
876        };
877        let rows = RowBatch {
878            num_rows: 3,
879            row_starts: vec![0_i32, 2, 4, 7],
880            col_indices: vec![0_i32, 3, 1, 2, 0, 2, 3],
881            values: vec![-1.0, 1.0, -2.0, 2.0, -3.0, 3.0, -4.0],
882            row_lower: vec![10.0, 20.0, 30.0],
883            row_upper: vec![f64::INFINITY; 3],
884        };
885
886        let mut scratch = BakingScratch::default();
887
888        // First reused-scratch bake.
889        let mut out1 = StageTemplate::empty();
890        bake_rows_into_template(&base, &rows, &mut out1, &mut scratch);
891
892        // Capacities after the first bake (must not shrink on reuse).
893        let cap_cut_nz = scratch.cut_nz_per_col.capacity();
894        let cap_col_start = scratch.col_list_start.capacity();
895        let cap_col_row = scratch.col_list_row.capacity();
896        let cap_col_val = scratch.col_list_val.capacity();
897        let cap_write_cursor = scratch.write_cursor.capacity();
898
899        // Second bake of the SAME inputs reusing the SAME scratch.
900        let mut out2 = StageTemplate::empty();
901        bake_rows_into_template(&base, &rows, &mut out2, &mut scratch);
902
903        // Reference bake with a fresh scratch.
904        let mut out_fresh = StageTemplate::empty();
905        bake_rows_into_template(&base, &rows, &mut out_fresh, &mut BakingScratch::default());
906
907        // Bit-identical outputs across reuse and fresh scratch.
908        assert_eq!(out1.col_starts, out2.col_starts);
909        assert_eq!(out1.col_starts, out_fresh.col_starts);
910        assert_eq!(out1.row_indices, out2.row_indices);
911        assert_eq!(out1.row_indices, out_fresh.row_indices);
912        assert_eq!(out1.values, out2.values);
913        assert_eq!(out1.values, out_fresh.values);
914        assert_eq!(out1.row_lower, out2.row_lower);
915        assert_eq!(out1.row_lower, out_fresh.row_lower);
916        assert_eq!(out1.row_upper, out2.row_upper);
917        assert_eq!(out1.row_upper, out_fresh.row_upper);
918        assert_eq!(out1.row_scale, out2.row_scale);
919        assert_eq!(out1.row_scale, out_fresh.row_scale);
920
921        // No downward realloc on the second (reused) bake.
922        assert!(scratch.cut_nz_per_col.capacity() >= cap_cut_nz);
923        assert!(scratch.col_list_start.capacity() >= cap_col_start);
924        assert!(scratch.col_list_row.capacity() >= cap_col_row);
925        assert!(scratch.col_list_val.capacity() >= cap_col_val);
926        assert!(scratch.write_cursor.capacity() >= cap_write_cursor);
927    }
928}