use crate::types::{RowBatch, StageTemplate};
/// Writes into `out` a copy of `base` with the constraint rows of `rows`
/// appended after the existing rows.
///
/// `base` is stored column-major: column `j` owns entries
/// `base.col_starts[j]..base.col_starts[j + 1]` of `base.row_indices` /
/// `base.values`. `rows` is stored row-major: row `r` owns entries
/// `rows.row_starts[r]..rows.row_starts[r + 1]` of `rows.col_indices` /
/// `rows.values`. Appended row `r` becomes row `base.num_rows + r` of `out`.
///
/// Within each output column the entries of `base` come first, in their
/// original order, followed by the appended entries in increasing row order,
/// so the result is fully deterministic.
///
/// Every field of `out` is overwritten. Its vectors are cleared and refilled
/// rather than replaced, so their allocations are reused across calls.
///
/// - Column bounds, objective, `col_scale` and the `n_*` / `max_par_order`
///   metadata are copied from `base`.
/// - Row bounds are `base`'s bounds followed by the first `rows.num_rows`
///   bounds of `rows`.
/// - `row_scale` is `base.row_scale` padded with `1.0` for each appended row.
///   When `base.row_scale` is empty it stays empty if no rows are appended,
///   and otherwise becomes all ones of length `out.num_rows`.
///
/// `rows.num_rows` and `rows.row_starts` are authoritative: only the first
/// `rows.row_starts[rows.num_rows]` entries of `rows.col_indices` /
/// `rows.values` are read (none at all when `rows.num_rows == 0`).
///
/// # Panics
///
/// - `"total nnz exceeds i32::MAX"` if the merged nonzero count does not fit
///   in `i32`.
/// - `"row index exceeds i32::MAX"` if an appended row index does not fit in
///   `i32`.
/// - On out-of-range column indices or offset/bound arrays that are too short.
///   Debug builds additionally check the structural invariants of both inputs
///   up front.
pub fn bake_rows_into_template(base: &StageTemplate, rows: &RowBatch, out: &mut StageTemplate) {
    debug_validate_bake_inputs(base, rows);

    // Only index `row_starts` when the batch actually has rows.
    #[allow(clippy::cast_sign_loss)]
    let rows_nnz = if rows.num_rows > 0 {
        rows.row_starts[rows.num_rows] as usize
    } else {
        0
    };
    let total_nnz = base.num_nz + rows_nnz;
    // Checked before any column index is used so an oversized model fails
    // with this message rather than an indexing panic.
    #[allow(clippy::expect_used)]
    let total_nnz_i32 = i32::try_from(total_nnz).expect("total nnz exceeds i32::MAX");
    let num_cols = base.num_cols;
    let num_rows = base.num_rows + rows.num_rows;

    out.num_cols = num_cols;
    out.num_rows = num_rows;
    out.num_nz = total_nnz;
    out.n_state = base.n_state;
    out.n_transfer = base.n_transfer;
    out.n_dual_relevant = base.n_dual_relevant;
    out.n_hydro = base.n_hydro;
    out.max_par_order = base.max_par_order;

    // Appending rows leaves per-column data untouched; `clone_from` reuses the
    // existing allocations in `out`.
    out.col_lower.clone_from(&base.col_lower);
    out.col_upper.clone_from(&base.col_upper);
    out.objective.clone_from(&base.objective);
    out.col_scale.clone_from(&base.col_scale);

    out.row_lower.clear();
    out.row_lower.extend_from_slice(&base.row_lower);
    out.row_lower.extend_from_slice(&rows.row_lower[..rows.num_rows]);
    out.row_upper.clear();
    out.row_upper.extend_from_slice(&base.row_upper);
    out.row_upper.extend_from_slice(&rows.row_upper[..rows.num_rows]);

    out.row_scale.clear();
    if !base.row_scale.is_empty() {
        out.row_scale.extend_from_slice(&base.row_scale);
        // Appended rows get a neutral scale of 1.0.
        out.row_scale.resize(base.row_scale.len() + rows.num_rows, 1.0_f64);
    } else if rows.num_rows > 0 {
        // An empty `base.row_scale` is materialised as all ones so every row,
        // including the appended ones, has an entry.
        out.row_scale.resize(num_rows, 1.0_f64);
    }

    let (bucket_starts, bucket_rows, bucket_vals) =
        bucket_rows_by_column(rows, rows_nnz, num_cols, base.num_rows);

    out.col_starts.clear();
    out.row_indices.clear();
    out.values.clear();
    // Reserve from the real entry counts (not the `num_nz` header) so a
    // malformed header cannot trigger an outsized allocation.
    out.col_starts.reserve(num_cols + 1);
    out.row_indices.reserve(base.row_indices.len() + rows_nnz);
    out.values.reserve(base.values.len() + rows_nnz);

    let mut nz_cursor: i32 = 0;
    #[allow(clippy::cast_sign_loss)]
    for j in 0..num_cols {
        out.col_starts.push(nz_cursor);

        let base_start = base.col_starts[j] as usize;
        let base_end = base.col_starts[j + 1] as usize;
        let base_range = base_start..base_end;
        let cut_range = bucket_starts[j]..bucket_starts[j + 1];
        let col_len = base_range.len() + cut_range.len();

        // Base entries first, then the appended entries (already row-sorted).
        out.row_indices
            .extend_from_slice(&base.row_indices[base_range.clone()]);
        out.values.extend_from_slice(&base.values[base_range]);
        out.row_indices
            .extend_from_slice(&bucket_rows[cut_range.clone()]);
        out.values.extend_from_slice(&bucket_vals[cut_range]);

        // For well-formed input the running total is bounded by `total_nnz`,
        // which was checked against `i32::MAX` above.
        #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
        {
            nz_cursor += col_len as i32;
        }
    }
    debug_assert_eq!(
        nz_cursor, total_nnz_i32,
        "merged entry count disagrees with base.num_nz + appended nnz"
    );
    out.col_starts.push(total_nnz_i32);
}

/// Asserts (in debug builds) the structural invariants of `base` and `rows`
/// that `bake_rows_into_template` relies on.
#[allow(clippy::cast_sign_loss)]
fn debug_validate_bake_inputs(base: &StageTemplate, rows: &RowBatch) {
    debug_assert_eq!(
        base.col_starts.len(),
        base.num_cols + 1,
        "base.col_starts.len()={} but num_cols+1={}",
        base.col_starts.len(),
        base.num_cols + 1
    );
    debug_assert!(
        base.col_starts.windows(2).all(|w| w[0] <= w[1]),
        "base.col_starts must be non-decreasing"
    );
    debug_assert_eq!(
        base.col_starts.last().copied().unwrap_or(0) as usize,
        base.num_nz,
        "base.col_starts[num_cols] != base.num_nz"
    );
    debug_assert_eq!(base.row_indices.len(), base.num_nz);
    debug_assert_eq!(base.values.len(), base.num_nz);
    debug_assert_eq!(base.col_lower.len(), base.num_cols);
    debug_assert_eq!(base.col_upper.len(), base.num_cols);
    debug_assert_eq!(base.objective.len(), base.num_cols);
    debug_assert_eq!(base.row_lower.len(), base.num_rows);
    debug_assert_eq!(base.row_upper.len(), base.num_rows);
    debug_assert!(
        base.col_scale.is_empty() || base.col_scale.len() == base.num_cols,
        "base.col_scale must be empty or length num_cols"
    );
    debug_assert!(
        base.row_scale.is_empty() || base.row_scale.len() == base.num_rows,
        "base.row_scale must be empty or length num_rows"
    );
    if rows.num_rows > 0 {
        debug_assert_eq!(
            rows.row_starts.len(),
            rows.num_rows + 1,
            "rows.row_starts.len()={} but num_rows+1={}",
            rows.row_starts.len(),
            rows.num_rows + 1
        );
        debug_assert_eq!(
            rows.row_starts[0], 0,
            "RowBatch invariant: row_starts[0] must be 0"
        );
        debug_assert!(
            rows.row_starts.windows(2).all(|w| w[0] <= w[1]),
            "RowBatch invariant: row_starts must be non-decreasing"
        );
        debug_assert_eq!(rows.row_lower.len(), rows.num_rows);
        debug_assert_eq!(rows.row_upper.len(), rows.num_rows);
        let rows_nnz = rows.row_starts[rows.num_rows] as usize;
        debug_assert_eq!(rows.col_indices.len(), rows_nnz);
        debug_assert_eq!(rows.values.len(), rows_nnz);
        #[cfg(debug_assertions)]
        for &col in &rows.col_indices {
            debug_assert!(
                (col as usize) < base.num_cols,
                "col_indices[k]={col} >= base.num_cols={}",
                base.num_cols
            );
        }
    }
}

/// Transposes the first `rows_nnz` entries of the row-major `rows` into
/// per-column buckets, tagging each entry with its output row index
/// `base_num_rows + r`.
///
/// Returns `(bucket_starts, bucket_rows, bucket_vals)`: the appended entries of
/// column `j` are `bucket_rows[bucket_starts[j]..bucket_starts[j + 1]]` (and
/// likewise in `bucket_vals`), in increasing row order.
#[allow(clippy::cast_sign_loss)]
fn bucket_rows_by_column(
    rows: &RowBatch,
    rows_nnz: usize,
    num_cols: usize,
    base_num_rows: usize,
) -> (Vec<usize>, Vec<i32>, Vec<f64>) {
    // Counting sort by column. Count into slot `col + 1`, then an inclusive
    // running sum turns the counts into bucket start offsets.
    let mut bucket_starts = vec![0usize; num_cols + 1];
    for &col in &rows.col_indices[..rows_nnz] {
        bucket_starts[col as usize + 1] += 1;
    }
    let mut running = 0usize;
    for slot in &mut bucket_starts {
        running += *slot;
        *slot = running;
    }

    let mut bucket_rows = vec![0i32; rows_nnz];
    let mut bucket_vals = vec![0.0f64; rows_nnz];
    let mut next_slot = bucket_starts[..num_cols].to_vec();
    // Visiting rows in order keeps every bucket sorted by row index.
    for r in 0..rows.num_rows {
        let start = rows.row_starts[r] as usize;
        let end = rows.row_starts[r + 1] as usize;
        #[allow(clippy::expect_used)]
        let out_row = i32::try_from(base_num_rows + r).expect("row index exceeds i32::MAX");
        for k in start..end {
            let j = rows.col_indices[k] as usize;
            let slot = next_slot[j];
            bucket_rows[slot] = out_row;
            bucket_vals[slot] = rows.values[k];
            next_slot[j] += 1;
        }
    }
    (bucket_starts, bucket_rows, bucket_vals)
}
// Unit tests for `bake_rows_into_template`.
#[cfg(test)]
mod tests {
use super::*;
use crate::types::{RowBatch, SolverStatistics, StageTemplate};
/// Small base model shared by most tests: 3 columns, 2 rows, 3 nonzeros.
/// Column 0 holds rows 0 and 1 (values 1.0, 2.0), column 1 is empty and
/// column 2 holds row 1 (value 1.0). No column or row scaling.
fn make_fixture_stage_template() -> StageTemplate {
StageTemplate {
num_cols: 3,
num_rows: 2,
num_nz: 3,
col_starts: vec![0_i32, 2, 2, 3],
row_indices: vec![0_i32, 1, 1],
values: vec![1.0, 2.0, 1.0],
col_lower: vec![0.0, 0.0, 0.0],
col_upper: vec![10.0, f64::INFINITY, 8.0],
objective: vec![0.0, 1.0, 50.0],
row_lower: vec![6.0, 14.0],
row_upper: vec![6.0, 14.0],
n_state: 1,
n_transfer: 0,
n_dual_relevant: 1,
n_hydro: 1,
max_par_order: 0,
col_scale: Vec::new(),
row_scale: Vec::new(),
}
}
/// A batch with zero rows; `row_starts` holds only the leading 0 offset.
fn make_empty_row_batch() -> RowBatch {
RowBatch {
num_rows: 0,
row_starts: vec![0_i32],
col_indices: vec![],
values: vec![],
row_lower: vec![],
row_upper: vec![],
}
}
// Baking an empty batch must reproduce the base template field by field and
// leave `row_scale` empty.
#[test]
fn test_bake_empty_rows_copies_base() {
let base = make_fixture_stage_template();
let rows = make_empty_row_batch();
let mut out = StageTemplate::empty();
bake_rows_into_template(&base, &rows, &mut out);
assert_eq!(out.num_cols, base.num_cols);
assert_eq!(out.num_rows, base.num_rows);
assert_eq!(out.num_nz, base.num_nz);
assert_eq!(out.col_starts, base.col_starts);
assert_eq!(out.row_indices, base.row_indices);
assert_eq!(out.values, base.values);
assert_eq!(out.col_lower, base.col_lower);
assert_eq!(out.col_upper, base.col_upper);
assert_eq!(out.objective, base.objective);
assert_eq!(out.row_lower, base.row_lower);
assert_eq!(out.row_upper, base.row_upper);
assert_eq!(out.n_state, base.n_state);
assert_eq!(out.n_transfer, base.n_transfer);
assert_eq!(out.n_dual_relevant, base.n_dual_relevant);
assert_eq!(out.n_hydro, base.n_hydro);
assert_eq!(out.max_par_order, base.max_par_order);
assert!(out.row_scale.is_empty());
}
// One appended row touching columns 0 and 2 becomes row 2. Its entries land at
// the end of those columns (after the base entries); column 1 stays empty.
#[test]
fn test_bake_single_row_appends_correct_column_entries() {
let base = make_fixture_stage_template();
let rows = RowBatch {
num_rows: 1,
row_starts: vec![0_i32, 2],
col_indices: vec![0_i32, 2],
values: vec![-1.5, 1.0],
row_lower: vec![10.0],
row_upper: vec![f64::INFINITY],
};
let mut out = StageTemplate::empty();
bake_rows_into_template(&base, &rows, &mut out);
assert_eq!(out.num_rows, 3);
assert_eq!(out.num_nz, 5);
assert_eq!(out.col_starts, vec![0_i32, 3, 3, 5]);
// Column 0: base rows 0 and 1, then appended row 2.
assert_eq!(&out.row_indices[0..3], &[0_i32, 1, 2]);
assert_eq!(&out.values[0..3], &[1.0_f64, 2.0, -1.5]);
// Column 2: base row 1, then appended row 2.
assert_eq!(&out.row_indices[3..5], &[1_i32, 2]);
assert_eq!(&out.values[3..5], &[1.0_f64, 1.0]);
assert_eq!(out.row_lower, vec![6.0_f64, 14.0, 10.0]);
assert!(out.row_upper[2].is_infinite() && out.row_upper[2] > 0.0);
}
// Existing base row scales are kept and the appended row gets scale 1.0.
#[test]
fn test_bake_preserves_row_scale_and_defaults_cut_rows_to_one() {
let mut base = make_fixture_stage_template();
base.row_scale = vec![1.0, 2.0];
let rows = RowBatch {
num_rows: 1,
row_starts: vec![0_i32, 1],
col_indices: vec![0_i32],
values: vec![-1.0],
row_lower: vec![5.0],
row_upper: vec![f64::INFINITY],
};
let mut out = StageTemplate::empty();
bake_rows_into_template(&base, &rows, &mut out);
assert_eq!(out.row_scale.len(), 3);
assert_eq!(out.row_scale[0], 1.0);
assert_eq!(out.row_scale[1], 2.0);
assert_eq!(out.row_scale[2], 1.0); }
// With an unscaled base and no appended rows, `row_scale` must remain empty.
#[test]
fn test_bake_preserves_empty_row_scale_when_no_rows() {
let base = make_fixture_stage_template(); let rows = make_empty_row_batch();
let mut out = StageTemplate::empty();
bake_rows_into_template(&base, &rows, &mut out);
assert!(out.row_scale.is_empty());
assert_eq!(out.num_rows, base.num_rows);
}
// Baking a smaller template into an `out` that previously held a larger one
// must not shrink the capacity of its vectors (buffers are reused, not
// reallocated).
#[test]
fn test_bake_reuses_out_buffer_capacity() {
let big_base = StageTemplate {
num_cols: 2,
num_rows: 5,
num_nz: 10,
col_starts: vec![0_i32, 5, 10],
row_indices: vec![0_i32, 1, 2, 3, 4, 0, 1, 2, 3, 4],
values: vec![1.0; 10],
col_lower: vec![0.0, 0.0],
col_upper: vec![f64::INFINITY, f64::INFINITY],
objective: vec![1.0, 1.0],
row_lower: vec![0.0; 5],
row_upper: vec![f64::INFINITY; 5],
n_state: 0,
n_transfer: 0,
n_dual_relevant: 0,
n_hydro: 0,
max_par_order: 0,
col_scale: Vec::new(),
row_scale: Vec::new(),
};
let empty_rows = make_empty_row_batch();
let mut out = StageTemplate::empty();
bake_rows_into_template(&big_base, &empty_rows, &mut out);
// Capacities after the first (larger) bake.
let cap_col_starts = out.col_starts.capacity();
let cap_row_indices = out.row_indices.capacity();
let cap_values = out.values.capacity();
let cap_row_lower = out.row_lower.capacity();
let cap_row_upper = out.row_upper.capacity();
let small_base = StageTemplate {
num_cols: 2,
num_rows: 4,
num_nz: 8,
col_starts: vec![0_i32, 4, 8],
row_indices: vec![0_i32, 1, 2, 3, 0, 1, 2, 3],
values: vec![1.0; 8],
col_lower: vec![0.0, 0.0],
col_upper: vec![f64::INFINITY, f64::INFINITY],
objective: vec![1.0, 1.0],
row_lower: vec![0.0; 4],
row_upper: vec![f64::INFINITY; 4],
n_state: 0,
n_transfer: 0,
n_dual_relevant: 0,
n_hydro: 0,
max_par_order: 0,
col_scale: Vec::new(),
row_scale: Vec::new(),
};
bake_rows_into_template(&small_base, &empty_rows, &mut out);
assert_eq!(out.num_rows, 4);
assert_eq!(out.num_nz, 8);
assert!(out.col_starts.capacity() >= cap_col_starts);
assert!(out.row_indices.capacity() >= cap_row_indices);
assert!(out.values.capacity() >= cap_values);
assert!(out.row_lower.capacity() >= cap_row_lower);
assert!(out.row_upper.capacity() >= cap_row_upper);
}
// Two bakes of identical inputs into fresh outputs must produce identical
// matrices and row bounds.
#[test]
fn test_bake_determinism() {
let base = make_fixture_stage_template();
let rows = RowBatch {
num_rows: 2,
row_starts: vec![0_i32, 2, 3],
col_indices: vec![0_i32, 2, 1],
values: vec![-1.0, 0.5, 3.0],
row_lower: vec![8.0, 12.0],
row_upper: vec![f64::INFINITY, f64::INFINITY],
};
let mut out1 = StageTemplate::empty();
let mut out2 = StageTemplate::empty();
bake_rows_into_template(&base, &rows, &mut out1);
bake_rows_into_template(&base, &rows, &mut out2);
assert_eq!(out1.col_starts, out2.col_starts);
assert_eq!(out1.row_indices, out2.row_indices);
assert_eq!(out1.values, out2.values);
assert_eq!(out1.row_lower, out2.row_lower);
assert_eq!(out1.row_upper, out2.row_upper);
}
// Three appended rows (3, 4, 5) scattered over four columns, one of them
// initially empty. Each column must contain its base entries followed by the
// appended entries in increasing row order.
#[test]
fn test_bake_multi_column_distribution() {
let base = StageTemplate {
num_cols: 4,
num_rows: 3,
num_nz: 6,
col_starts: vec![0_i32, 2, 3, 6, 6],
row_indices: vec![0_i32, 1, 2, 0, 1, 2],
values: vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
col_lower: vec![0.0; 4],
col_upper: vec![f64::INFINITY; 4],
objective: vec![0.0; 4],
row_lower: vec![0.0; 3],
row_upper: vec![f64::INFINITY; 3],
n_state: 2,
n_transfer: 1,
n_dual_relevant: 2,
n_hydro: 2,
max_par_order: 1,
col_scale: Vec::new(),
row_scale: Vec::new(),
};
let rows = RowBatch {
num_rows: 3,
row_starts: vec![0_i32, 2, 4, 7],
col_indices: vec![0_i32, 3, 1, 2, 0, 2, 3],
values: vec![-1.0, 1.0, -2.0, 2.0, -3.0, 3.0, -4.0],
row_lower: vec![10.0, 20.0, 30.0],
row_upper: vec![f64::INFINITY; 3],
};
let mut out = StageTemplate::empty();
bake_rows_into_template(&base, &rows, &mut out);
assert_eq!(out.num_rows, 6);
assert_eq!(out.num_nz, 13);
assert_eq!(out.col_starts, vec![0_i32, 4, 6, 11, 13]);
assert_eq!(&out.row_indices[0..4], &[0_i32, 1, 3, 5]);
assert_eq!(&out.values[0..4], &[1.0_f64, 2.0, -1.0, -3.0]);
assert_eq!(&out.row_indices[4..6], &[2_i32, 4]);
assert_eq!(&out.values[4..6], &[3.0_f64, -2.0]);
assert_eq!(&out.row_indices[6..11], &[0_i32, 1, 2, 4, 5]);
assert_eq!(&out.values[6..11], &[4.0_f64, 5.0, 6.0, 2.0, 3.0]);
// Column 3 was empty in the base, so it holds only appended entries.
assert_eq!(&out.row_indices[11..13], &[3_i32, 5]);
assert_eq!(&out.values[11..13], &[1.0_f64, -4.0]);
assert_eq!(&out.row_lower[3..6], &[10.0_f64, 20.0, 30.0]);
}
/// Minimal `SolverInterface` stand-in. It records the row count of the last
/// template passed to `load_model` and counts `load_model` calls in `stats`.
/// `solve` always fails with an `InternalError`; every other method is a no-op.
struct MockSolver {
last_loaded_num_rows: usize,
stats: SolverStatistics,
}
impl MockSolver {
fn new() -> Self {
Self {
last_loaded_num_rows: 0,
stats: SolverStatistics::default(),
}
}
}
impl crate::SolverInterface for MockSolver {
type Profile = crate::HighsProfile;
fn apply_profile(&mut self, _profile: &crate::HighsProfile) {}
fn load_model(&mut self, template: &StageTemplate) {
self.last_loaded_num_rows = template.num_rows;
self.stats.load_model_count += 1;
}
fn add_rows(&mut self, _rows: &RowBatch) {}
fn set_row_bounds(&mut self, _indices: &[usize], _lower: &[f64], _upper: &[f64]) {}
fn set_col_bounds(&mut self, _indices: &[usize], _lower: &[f64], _upper: &[f64]) {}
fn solve(
&mut self,
_basis: Option<&crate::types::Basis>,
) -> Result<crate::types::SolutionView<'_>, crate::types::SolverError> {
Err(crate::types::SolverError::InternalError {
message: "mock".to_string(),
error_code: None,
})
}
fn get_basis(&mut self, _out: &mut crate::types::Basis) {}
fn statistics(&self) -> SolverStatistics {
self.stats.clone()
}
fn name(&self) -> &'static str {
"Mock"
}
fn solver_name_version(&self) -> String {
"MockSolver 0.0.0".to_string()
}
fn set_primal_feasibility_tolerance(&mut self, _value: f64) {}
fn set_dual_feasibility_tolerance(&mut self, _value: f64) {}
fn set_simplex_iteration_limit_profile(&mut self, _value: u32) {}
fn set_ipm_iteration_limit_profile(&mut self, _value: u32) {}
}
// Loading the baked template must be a single `load_model` call, and the
// solver must see base rows plus appended rows.
#[test]
fn test_bake_load_model_row_count() {
use crate::SolverInterface;
let base = make_fixture_stage_template();
let rows = RowBatch {
num_rows: 3,
row_starts: vec![0_i32, 1, 2, 3],
col_indices: vec![0_i32, 1, 2],
values: vec![-1.0, -1.0, -1.0],
row_lower: vec![5.0, 6.0, 7.0],
row_upper: vec![f64::INFINITY; 3],
};
let mut out = StageTemplate::empty();
bake_rows_into_template(&base, &rows, &mut out);
let expected_rows = base.num_rows + rows.num_rows;
let mut solver = MockSolver::new();
let before = solver.statistics().load_model_count;
solver.load_model(&out);
let after = solver.statistics().load_model_count;
assert_eq!(after - before, 1);
assert_eq!(solver.last_loaded_num_rows, expected_rows);
}
// With an unscaled base (empty `row_scale`) and appended rows, `row_scale`
// becomes all ones covering every output row.
#[test]
fn test_bake_empty_base_row_scale_with_cut_rows_appends_ones() {
let base = make_fixture_stage_template(); let rows = RowBatch {
num_rows: 2,
row_starts: vec![0_i32, 1, 2],
col_indices: vec![0_i32, 0],
values: vec![-1.0, -2.0],
row_lower: vec![5.0, 6.0],
row_upper: vec![f64::INFINITY; 2],
};
let mut out = StageTemplate::empty();
bake_rows_into_template(&base, &rows, &mut out);
assert_eq!(out.row_scale.len(), base.num_rows + rows.num_rows);
assert!(out.row_scale.iter().all(|&s| s == 1.0));
}
// Release-only. The fixture deliberately sets `num_nz` to i32::MAX without a
// matching `col_starts`. Debug builds reject that with a different panic
// message before the overflow check runs. One appended nonzero pushes the
// total past i32::MAX. Column 0 does not exist when `num_cols == 0`, so this
// also relies on the nnz check running before the column indices are used.
#[test]
#[cfg(not(debug_assertions))]
#[should_panic(expected = "total nnz exceeds i32::MAX")]
fn test_bake_panics_on_nnz_overflow() {
let large_num_nz = usize::try_from(i32::MAX).unwrap(); let base = StageTemplate {
num_cols: 0,
num_rows: 0,
num_nz: large_num_nz,
col_starts: vec![0_i32], row_indices: vec![], values: vec![],
col_lower: vec![],
col_upper: vec![],
objective: vec![],
row_lower: vec![],
row_upper: vec![],
n_state: 0,
n_transfer: 0,
n_dual_relevant: 0,
n_hydro: 0,
max_par_order: 0,
col_scale: Vec::new(),
row_scale: Vec::new(),
};
let rows = RowBatch {
num_rows: 1,
row_starts: vec![0_i32, 1],
col_indices: vec![0_i32],
values: vec![1.0],
row_lower: vec![0.0],
row_upper: vec![f64::INFINITY],
};
let mut out = StageTemplate::empty();
bake_rows_into_template(&base, &rows, &mut out);
}
}