use super::implicit_diff;
use super::types::{DiffLPConfig, DiffLPResult, ImplicitGradient};
use crate::error::{OptimizeError, OptimizeResult};
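/// A linear program of the form
///
/// ```text
/// minimize    c^T x
/// subject to  A_eq x = b_eq
///             G x <= h
/// ```
///
/// made differentiable by solving a regularized QP in [`forward`](Self::forward)
/// and implicitly differentiating its KKT conditions in [`backward`](Self::backward).
///
/// A sketch of intended usage, mirroring the tests below (marked `ignore`
/// because the public path of this type depends on the crate's re-exports):
///
/// ```ignore
/// // Maximize x + y (minimize -x - y) subject to x + y <= 1, x >= 0, y >= 0.
/// let lp = DifferentiableLP::new(
///     vec![-1.0, -1.0],
///     vec![],
///     vec![],
///     vec![vec![1.0, 1.0], vec![-1.0, 0.0], vec![0.0, -1.0]],
///     vec![1.0, 0.0, 0.0],
/// )?;
/// let config = DiffLPConfig::default();
/// let result = lp.forward(&config)?;
/// let grad = lp.backward(&result, &[1.0, 1.0], &config)?;
/// assert!(grad.dl_dc.iter().all(|g| g.is_finite()));
/// ```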
#[derive(Debug, Clone)]
pub struct DifferentiableLP {
    /// Objective coefficients `c` (length `n`).
    pub c: Vec<f64>,
    /// Equality constraint matrix `A_eq`; every row has length `n`.
    pub a_eq: Vec<Vec<f64>>,
    /// Equality right-hand side `b_eq`, one entry per row of `A_eq`.
    pub b_eq: Vec<f64>,
    /// Inequality constraint matrix `G`; every row has length `n`.
    pub g: Vec<Vec<f64>>,
    /// Inequality right-hand side `h`, one entry per row of `G`.
    pub h: Vec<f64>,
}
impl DifferentiableLP {
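    /// Validates and constructs an LP: every row of `a_eq` and `g` must have
    /// length `c.len()`, and `b_eq`/`h` must match the row counts of
    /// `a_eq`/`g`. Returns `OptimizeError::InvalidInput` on any mismatch.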
pub fn new(
c: Vec<f64>,
a_eq: Vec<Vec<f64>>,
b_eq: Vec<f64>,
g: Vec<Vec<f64>>,
h: Vec<f64>,
) -> OptimizeResult<Self> {
let n = c.len();
for (i, row) in a_eq.iter().enumerate() {
if row.len() != n {
return Err(OptimizeError::InvalidInput(format!(
"A_eq row {} has length {} but expected {}",
i,
row.len(),
n
)));
}
}
if a_eq.len() != b_eq.len() {
return Err(OptimizeError::InvalidInput(format!(
"A_eq has {} rows but b_eq has length {}",
a_eq.len(),
b_eq.len()
)));
}
for (i, row) in g.iter().enumerate() {
if row.len() != n {
return Err(OptimizeError::InvalidInput(format!(
"G row {} has length {} but expected {}",
i,
row.len(),
n
)));
}
}
if g.len() != h.len() {
return Err(OptimizeError::InvalidInput(format!(
"G has {} rows but h has length {}",
g.len(),
h.len()
)));
}
Ok(Self {
c,
a_eq,
b_eq,
g,
h,
})
}
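    /// Number of decision variables, i.e. the length of `c`.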
pub fn n(&self) -> usize {
self.c.len()
}
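    /// Solves the LP by lifting it to a strictly convex QP with
    /// `Q = config.regularization * I`, i.e. minimizing
    /// `0.5 * reg * ||x||^2 + c^T x` over the same constraints. For small
    /// `reg` the solution approximates the LP optimum while remaining a
    /// well-defined, differentiable function of the problem data.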
pub fn forward(&self, config: &DiffLPConfig) -> OptimizeResult<DiffLPResult> {
let n = self.n();
        // Tikhonov regularizer Q = regularization * I makes the QP strictly convex.
let mut q = vec![vec![0.0; n]; n];
for i in 0..n {
q[i][i] = config.regularization;
}
let qp = super::diff_qp::DifferentiableQP {
q,
c: self.c.clone(),
g: self.g.clone(),
h: self.h.clone(),
a: self.a_eq.clone(),
b: self.b_eq.clone(),
};
let qp_config = super::types::DiffQPConfig {
tolerance: config.tolerance,
max_iterations: config.max_iterations,
regularization: config.regularization,
backward_mode: super::types::BackwardMode::FullDifferentiation,
};
let qp_result = qp.forward(&qp_config)?;
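        // Report the pure LP objective c^T x; the regularization term is excluded.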
let mut obj = 0.0;
for i in 0..n {
obj += self.c[i] * qp_result.optimal_x[i];
}
Ok(DiffLPResult {
optimal_x: qp_result.optimal_x,
optimal_lambda: qp_result.optimal_lambda,
optimal_nu: qp_result.optimal_nu,
objective: obj,
converged: qp_result.converged,
iterations: qp_result.iterations,
})
}
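    /// Given the loss gradient `dl_dx` at the solution in `result`, computes
    /// gradients with respect to the problem data by implicit differentiation
    /// of the regularized QP's KKT conditions, restricted to the inequality
    /// constraints active within `config.active_constraint_tol`.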
pub fn backward(
&self,
result: &DiffLPResult,
dl_dx: &[f64],
config: &DiffLPConfig,
) -> OptimizeResult<ImplicitGradient> {
let n = self.n();
if dl_dx.len() != n {
return Err(OptimizeError::InvalidInput(format!(
"dl_dx length {} != n {}",
dl_dx.len(),
n
)));
}
let mut q = vec![vec![0.0; n]; n];
for i in 0..n {
q[i][i] = config.regularization;
}
implicit_diff::compute_active_set_implicit_gradient(
&q,
&self.g,
&self.h,
&self.a_eq,
&result.optimal_x,
&result.optimal_lambda,
&result.optimal_nu,
dl_dx,
config.active_constraint_tol,
)
}
}
#[cfg(test)]
mod tests {
use super::*;
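    // Minimize -x - y over x + y <= 1, x >= 0, y >= 0: every point on the
    // face x + y = 1 is optimal, with objective value -1.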
#[test]
fn test_lp_forward_simple() {
let lp = DifferentiableLP::new(
vec![-1.0, -1.0],
vec![],
vec![],
vec![
                vec![1.0, 1.0],
                vec![-1.0, 0.0],
                vec![0.0, -1.0],
            ],
vec![1.0, 0.0, 0.0],
)
.expect("LP creation failed");
let config = DiffLPConfig::default();
let result = lp.forward(&config).expect("Forward failed");
assert!(result.converged, "LP should converge");
let sum: f64 = result.optimal_x.iter().sum();
assert!((sum - 1.0).abs() < 0.1, "x+y = {} (expected ~1.0)", sum);
assert!(
(result.objective - (-1.0)).abs() < 0.1,
"obj = {} (expected ~-1.0)",
result.objective
);
}
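    // Minimize -x subject to x + y = 1, x >= 0, y >= 0: optimum at x = 1, y = 0.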
#[test]
fn test_lp_with_equality() {
let lp = DifferentiableLP::new(
vec![-1.0, 0.0],
vec![vec![1.0, 1.0]],
vec![1.0],
vec![
                vec![-1.0, 0.0],
                vec![0.0, -1.0],
            ],
vec![0.0, 0.0],
)
.expect("LP creation failed");
let config = DiffLPConfig::default();
let result = lp.forward(&config).expect("Forward failed");
assert!(result.converged);
assert!(
(result.optimal_x[0] - 1.0).abs() < 0.1,
"x = {} (expected ~1.0)",
result.optimal_x[0]
);
}
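    // The backward pass through the simplex LP should yield finite dL/dc.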
#[test]
fn test_lp_backward() {
let lp = DifferentiableLP::new(
vec![-1.0, -1.0],
vec![],
vec![],
vec![vec![1.0, 1.0], vec![-1.0, 0.0], vec![0.0, -1.0]],
vec![1.0, 0.0, 0.0],
)
.expect("LP creation failed");
let config = DiffLPConfig::default();
let result = lp.forward(&config).expect("Forward failed");
let dl_dx = vec![1.0, 1.0];
let grad = lp
.backward(&result, &dl_dx, &config)
.expect("Backward failed");
assert_eq!(grad.dl_dc.len(), 2);
assert!(grad.dl_dc[0].is_finite());
assert!(grad.dl_dc[1].is_finite());
}
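    // An A_eq row whose length differs from c.len() must be rejected.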
#[test]
fn test_lp_dimension_validation() {
let result = DifferentiableLP::new(
vec![1.0, 2.0],
            vec![vec![1.0]],
            vec![1.0],
vec![],
vec![],
);
assert!(result.is_err());
}
}