use crate::variables::{VarId, Vars, Val};
use crate::variables::views::Context;
use crate::lpsolver::{LpProblem, LpSolution, LpStatus};
/// Relation between a linear expression `sum(a_i * x_i)` and the right-hand
/// side of a [`LinearConstraint`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConstraintRelation {
    /// `sum(a_i * x_i) == rhs`
    Equality,
    /// `sum(a_i * x_i) <= rhs`
    LessOrEqual,
    /// `sum(a_i * x_i) >= rhs`
    GreaterOrEqual,
}
/// A single linear constraint `sum(coefficients[i] * variables[i]) <relation> rhs`.
#[derive(Debug, Clone)]
pub struct LinearConstraint {
    /// Coefficient for each entry of `variables`, paired by position.
    pub coefficients: Vec<f64>,
    /// Variables referenced by the constraint; may repeat a `VarId`.
    pub variables: Vec<VarId>,
    /// How the linear expression relates to `rhs`.
    pub relation: ConstraintRelation,
    /// Right-hand side constant.
    pub rhs: f64,
}
impl LinearConstraint {
pub fn new(
coefficients: Vec<f64>,
variables: Vec<VarId>,
relation: ConstraintRelation,
rhs: f64,
) -> Self {
debug_assert_eq!(
coefficients.len(),
variables.len(),
"Coefficients and variables must have same length"
);
Self {
coefficients,
variables,
relation,
rhs,
}
}
pub fn equality(coefficients: Vec<f64>, variables: Vec<VarId>, rhs: f64) -> Self {
Self::new(coefficients, variables, ConstraintRelation::Equality, rhs)
}
pub fn less_or_equal(coefficients: Vec<f64>, variables: Vec<VarId>, rhs: f64) -> Self {
Self::new(coefficients, variables, ConstraintRelation::LessOrEqual, rhs)
}
pub fn greater_or_equal(coefficients: Vec<f64>, variables: Vec<VarId>, rhs: f64) -> Self {
Self::new(coefficients, variables, ConstraintRelation::GreaterOrEqual, rhs)
}
pub fn to_standard_form(&self) -> Vec<(Vec<f64>, f64)> {
match self.relation {
ConstraintRelation::LessOrEqual => {
vec![(self.coefficients.clone(), self.rhs)]
}
ConstraintRelation::GreaterOrEqual => {
let neg_coeffs: Vec<f64> = self.coefficients.iter().map(|&c| -c).collect();
vec![(neg_coeffs, -self.rhs)]
}
ConstraintRelation::Equality => {
let neg_coeffs: Vec<f64> = self.coefficients.iter().map(|&c| -c).collect();
vec![
(self.coefficients.clone(), self.rhs),
(neg_coeffs, -self.rhs),
]
}
}
}
}
/// A collection of linear constraints over a shared set of variables, with an
/// optional linear objective.
#[derive(Debug, Clone)]
pub struct LinearConstraintSystem {
    /// Every variable referenced by `constraints`. `add_constraint` keeps it
    /// free of duplicates and in first-seen order.
    pub variables: Vec<VarId>,
    /// Constraints in insertion order.
    pub constraints: Vec<LinearConstraint>,
    /// Optional objective; its coefficients are matched to `variables` by position.
    pub objective: Option<LinearObjective>,
}
/// Linear objective `sum(coefficients[i] * variables[i])` of a
/// [`LinearConstraintSystem`].
#[derive(Debug, Clone)]
pub struct LinearObjective {
    /// Coefficient for each entry of the owning system's `variables`, by position.
    pub coefficients: Vec<f64>,
    /// `true` to minimize, `false` to maximize. `to_lp_problem` negates the
    /// coefficients when this is `true`.
    pub minimize: bool,
}
impl LinearConstraintSystem {
    /// Creates an empty system with no variables, constraints, or objective.
    pub fn new() -> Self {
        Self {
            variables: Vec::new(),
            constraints: Vec::new(),
            objective: None,
        }
    }

    /// Appends `constraint`.
    ///
    /// Any of its variables not already in `self.variables` is registered,
    /// preserving first-seen order.
    pub fn add_constraint(&mut self, constraint: LinearConstraint) {
        for &var in &constraint.variables {
            if !self.variables.contains(&var) {
                self.variables.push(var);
            }
        }
        self.constraints.push(constraint);
    }

    /// Sets the linear objective.
    ///
    /// `coefficients[i]` applies to `self.variables[i]`, so call this after
    /// all constraints have been added.
    ///
    /// # Panics
    ///
    /// Panics if `coefficients.len() != self.variables.len()`.
    pub fn set_objective(&mut self, coefficients: Vec<f64>, minimize: bool) {
        assert_eq!(coefficients.len(), self.variables.len(),
            "Objective coefficients must match number of variables");
        self.objective = Some(LinearObjective { coefficients, minimize });
    }

    /// Cheap structural check: at least one constraint and at least two variables.
    ///
    /// `_vars` is currently unused.
    pub fn is_suitable_for_lp(&self, _vars: &Vars) -> bool {
        !self.constraints.is_empty() && self.variables.len() >= 2
    }

    /// Builds an [`LpProblem`] from this system using the current bounds in `vars`.
    ///
    /// - Variables whose bound width is below the LP feasibility tolerance are
    ///   treated as constants. Their value is the lower bound, and it is folded
    ///   into each row's right-hand side.
    /// - All other variables become LP columns, in `self.variables` order,
    ///   with their current bounds.
    /// - Every constraint is expanded via [`LinearConstraint::to_standard_form`]
    ///   into one or two `row · x <= rhs` rows.
    /// - Objective coefficients are negated when minimizing.
    /// - Constant variables' objective terms are a fixed offset and are dropped.
    pub fn to_lp_problem(&self, vars: &Vars) -> LpProblem {
        use std::collections::HashMap;

        let tolerance = crate::lpsolver::LpConfig::default().feasibility_tol;
        // Per-constraint debug output is only emitted for small systems.
        let verbose = self.constraints.len() <= 20;

        // Partition variables into LP columns and fixed constants. Bounds are
        // read once here and reused for the column bound vectors.
        let mut var_to_lp_index: HashMap<VarId, usize> = HashMap::new();
        let mut constants: HashMap<VarId, f64> = HashMap::new();
        let mut lower_bounds: Vec<f64> = Vec::new();
        let mut upper_bounds: Vec<f64> = Vec::new();
        for &var in &self.variables {
            let (lower, upper) = extract_bounds(var, vars);
            if (upper - lower).abs() < tolerance {
                constants.insert(var, lower);
            } else {
                var_to_lp_index.insert(var, lower_bounds.len());
                lower_bounds.push(lower);
                upper_bounds.push(upper);
            }
        }
        let n_vars = lower_bounds.len();

        // Objective vector. `zip` stops at the shorter side, so variables
        // registered after `set_objective` get a zero coefficient.
        let mut c = vec![0.0; n_vars];
        if let Some(obj) = &self.objective {
            for (var, &coeff) in self.variables.iter().zip(&obj.coefficients) {
                if let Some(&lp_idx) = var_to_lp_index.get(var) {
                    // NOTE(review): negating for minimization suggests LpProblem
                    // maximizes c·x — confirm against the lpsolver module.
                    c[lp_idx] = if obj.minimize { -coeff } else { coeff };
                }
            }
        }

        let estimated_rows = self.constraints.len() * 2;
        let mut a: Vec<Vec<f64>> = Vec::with_capacity(estimated_rows);
        let mut b: Vec<f64> = Vec::with_capacity(estimated_rows);
        if !verbose {
            lp_debug!("LP BUILD: Processing {} constraints with {} variables (output suppressed for performance)...",
                self.constraints.len(), n_vars);
        }
        for constraint in &self.constraints {
            if verbose {
                lp_debug!("LP BUILD: Converting constraint with {} vars, relation {:?}, rhs {}",
                    constraint.variables.len(), constraint.relation, constraint.rhs);
            }
            for (std_coeffs, std_rhs) in constraint.to_standard_form() {
                let mut row = vec![0.0; n_vars];
                let mut rhs_adjusted = std_rhs;
                for (j, &var) in constraint.variables.iter().enumerate() {
                    let coeff = std_coeffs[j];
                    if let Some(&const_val) = constants.get(&var) {
                        rhs_adjusted -= coeff * const_val;
                        if verbose {
                            lp_debug!("LP BUILD: var {:?} is constant = {}, adjusting RHS by -{} * {} = {}",
                                var, const_val, coeff, const_val, -coeff * const_val);
                        }
                    } else if let Some(&lp_idx) = var_to_lp_index.get(&var) {
                        // Accumulate rather than assign: the same variable may occur
                        // several times in one constraint (e.g. `x + 2x`), and
                        // overwriting would silently drop the earlier terms.
                        row[lp_idx] += coeff;
                    }
                }
                if verbose {
                    lp_debug!("LP BUILD: Constraint row = {:?}, rhs = {}", row, rhs_adjusted);
                }
                a.push(row);
                b.push(rhs_adjusted);
            }
        }

        let n_constraints = a.len();
        lp_debug!("LP BUILD: Final problem: {} variables (excluding {} constants), {} constraints",
            n_vars, constants.len(), n_constraints);
        LpProblem::new(n_vars, n_constraints, c, a, b, lower_bounds, upper_bounds)
    }

    /// Number of distinct variables registered in the system.
    pub fn n_variables(&self) -> usize {
        self.variables.len()
    }

    /// Number of constraints as added; equalities are not yet expanded into two rows.
    pub fn n_constraints(&self) -> usize {
        self.constraints.len()
    }
}
impl Default for LinearConstraintSystem {
fn default() -> Self {
Self::new()
}
}
fn extract_bounds(var: VarId, vars: &Vars) -> (f64, f64) {
use crate::variables::views::ViewRaw;
let lower = match var.min_raw(vars) {
Val::ValF(f) => f,
Val::ValI(i) => i as f64,
};
let upper = match var.max_raw(vars) {
Val::ValF(f) => f,
Val::ValI(i) => i as f64,
};
(lower, upper)
}
/// Writes an optimal LP solution back into the domains held by `ctx`.
///
/// Returns `Some(())` without touching any domain when:
/// - `solution.status` is not `LpStatus::Optimal`, or
/// - the solution length does not match the number of non-fixed system
///   variables. Columns are rebuilt from the current bounds using the same
///   tolerance as `to_lp_problem`.
///
/// For each column variable whose LP value lies within tolerance of its
/// current domain:
/// - raises the minimum to the LP value if it is above the lower bound (plus
///   tolerance);
/// - lowers the maximum to the LP value if it is below the upper bound (minus
///   tolerance).
///
/// Returns `None` as soon as either bound update fails.
///
/// NOTE(review): taken together, these updates pin each free variable to
/// (approximately) its LP value. An LP optimum is only one feasible point, so
/// this prunes other solutions. Confirm this is intended (e.g. a heuristic
/// commit) rather than sound propagation.
pub fn apply_lp_solution(
    system: &LinearConstraintSystem,
    solution: &LpSolution,
    ctx: &mut Context,
) -> Option<()> {
    use crate::variables::views::View;

    let optimal = solution.status == LpStatus::Optimal;
    if !optimal {
        return Some(());
    }

    let tol = crate::lpsolver::LpConfig::default().feasibility_tol;

    // Reconstruct the LP column order: non-fixed system variables, in system order.
    let columns: Vec<VarId> = system
        .variables
        .iter()
        .copied()
        .filter(|&candidate| {
            let (lo, hi) = extract_bounds(candidate, ctx.vars());
            (hi - lo).abs() >= tol
        })
        .collect();

    if solution.x.len() != columns.len() {
        lp_debug!("LP APPLY: WARNING: LP solution has {} variables but expected {}",
            solution.x.len(), columns.len());
        return Some(());
    }

    for (&var_id, &lp_value) in columns.iter().zip(solution.x.iter()) {
        let (lo, hi) = extract_bounds(var_id, ctx.vars());
        lp_debug!("LP APPLY: var {:?} LP_value={} bounds=[{}, {}]", var_id, lp_value, lo, hi);

        if (hi - lo).abs() < tol {
            lp_debug!("LP APPLY: var {:?} is constant, skipping", var_id);
            continue;
        }

        // Ignore LP values that fall outside the current domain.
        let inside = lp_value >= lo - tol && lp_value <= hi + tol;
        if !inside {
            continue;
        }

        if lp_value > lo + tol {
            var_id.try_set_min(Val::ValF(lp_value), ctx)?;
        }
        if lp_value < hi - tol {
            var_id.try_set_max(Val::ValF(lp_value), ctx)?;
        }
    }

    Some(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `VarId` from a raw index for test fixtures.
    fn vid(index: usize) -> VarId {
        VarId::from_index(index)
    }

    #[test]
    fn test_linear_constraint_creation() {
        let ids = [vid(0), vid(1)];
        let weights = vec![2.0, 3.0];
        let built = LinearConstraint::equality(weights.clone(), ids.to_vec(), 10.0);
        assert_eq!(built.coefficients, weights);
        assert_eq!(built.variables.len(), ids.len());
        assert_eq!(built.relation, ConstraintRelation::Equality);
        assert_eq!(built.rhs, 10.0);
    }

    #[test]
    fn test_constraint_to_standard_form() {
        let ids = vec![vid(0), vid(1)];
        let weights = vec![1.0, 2.0];
        let flipped = vec![-1.0, -2.0];

        let le = LinearConstraint::less_or_equal(weights.clone(), ids.clone(), 5.0);
        assert_eq!(le.to_standard_form(), vec![(weights.clone(), 5.0)]);

        let ge = LinearConstraint::greater_or_equal(weights.clone(), ids.clone(), 5.0);
        assert_eq!(ge.to_standard_form(), vec![(flipped.clone(), -5.0)]);

        let eq = LinearConstraint::equality(weights.clone(), ids, 5.0);
        assert_eq!(eq.to_standard_form(), vec![(weights, 5.0), (flipped, -5.0)]);
    }

    #[test]
    fn test_linear_system_creation() {
        let mut system = LinearConstraintSystem::new();
        assert_eq!((system.n_variables(), system.n_constraints()), (0, 0));

        system.add_constraint(LinearConstraint::less_or_equal(
            vec![1.0, 2.0],
            vec![vid(0), vid(1)],
            5.0,
        ));
        assert_eq!((system.n_variables(), system.n_constraints()), (2, 1));
    }

    #[test]
    fn test_system_deduplicates_variables() {
        let mut system = LinearConstraintSystem::new();
        // vid(1) is shared by both constraints and must only be registered once.
        let pieces = [
            LinearConstraint::less_or_equal(vec![1.0, 2.0], vec![vid(0), vid(1)], 5.0),
            LinearConstraint::less_or_equal(vec![3.0, 1.0], vec![vid(1), vid(2)], 10.0),
        ];
        for piece in pieces {
            system.add_constraint(piece);
        }
        assert_eq!(system.n_variables(), 3);
        assert_eq!(system.n_constraints(), 2);
    }
}