use super::diff_qp::DifferentiableQP;
use super::types::{DiffQPConfig, DiffQPResult, ImplicitGradient};
use crate::error::OptimizeResult;
/// A differentiable layer with a forward pass and a gradient-propagating backward pass.
///
/// Implementors produce a [`OptNetLayer::ForwardResult`] from `forward`; that same
/// value is later handed back to `backward` together with an upstream gradient.
pub trait OptNetLayer {
    /// Value produced by [`OptNetLayer::forward`] and consumed again by
    /// [`OptNetLayer::backward`].
    type ForwardResult;

    /// Runs the forward pass of the layer.
    ///
    /// # Errors
    /// Returns an error if the implementor's forward computation fails.
    fn forward(&self) -> OptimizeResult<Self::ForwardResult>;

    /// Runs the backward pass for a previously computed forward `result`.
    ///
    /// `dl_dx` is presumably the gradient of a downstream loss with respect to the
    /// forward solution `x` (TODO confirm against implementors); the returned
    /// [`ImplicitGradient`] carries the propagated gradients.
    ///
    /// # Errors
    /// Returns an error if the implementor's backward computation fails.
    fn backward(
        &self,
        result: &Self::ForwardResult,
        dl_dx: &[f64],
    ) -> OptimizeResult<ImplicitGradient>;
}
/// An [`OptNetLayer`] backed by a single [`DifferentiableQP`] and the
/// [`DiffQPConfig`] it is solved and differentiated with.
#[derive(Clone, Debug)]
pub struct StandardOptNetLayer {
    /// The quadratic program this layer wraps.
    pub qp: DifferentiableQP,
    /// Solver configuration handed to the QP on every pass.
    pub config: DiffQPConfig,
}
impl StandardOptNetLayer {
pub fn new(qp: DifferentiableQP, config: DiffQPConfig) -> Self {
Self { qp, config }
}
pub fn forward_batch(
qps: &[DifferentiableQP],
config: &DiffQPConfig,
) -> OptimizeResult<Vec<DiffQPResult>> {
DifferentiableQP::batched_forward(qps, config)
}
}
impl OptNetLayer for StandardOptNetLayer {
    type ForwardResult = DiffQPResult;

    /// Solves the wrapped QP under the layer's stored configuration.
    fn forward(&self) -> OptimizeResult<Self::ForwardResult> {
        let Self { qp, config } = self;
        qp.forward(config)
    }

    /// Differentiates the wrapped QP at `result`, using the same stored
    /// configuration as the forward pass.
    fn backward(
        &self,
        result: &Self::ForwardResult,
        dl_dx: &[f64],
    ) -> OptimizeResult<ImplicitGradient> {
        let Self { qp, config } = self;
        qp.backward(result, dl_dx, config)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Forward/backward through a `dyn OptNetLayer` trait object on an unconstrained
    /// 2-D QP with Q = 2I and c = [1, 2].
    ///
    /// Under the x* = -Q^{-1} c convention that `test_layer_batch_interface` also
    /// relies on, the minimiser is [-0.5, -1.0].
    #[test]
    fn test_layer_trait_dispatch() {
        let qp = DifferentiableQP::new(
            vec![vec![2.0, 0.0], vec![0.0, 2.0]],
            vec![1.0, 2.0],
            vec![],
            vec![],
            vec![],
            vec![],
        )
        .expect("QP creation failed");
        let concrete = StandardOptNetLayer::new(qp, DiffQPConfig::default());
        // Use a trait object so the calls really go through `OptNetLayer` dispatch
        // instead of being resolved on the concrete type at compile time.
        let layer: &dyn OptNetLayer<ForwardResult = DiffQPResult> = &concrete;

        let result = layer.forward().expect("Forward failed");
        assert!(result.converged);

        let expected_x = [-0.5, -1.0];
        assert_eq!(result.optimal_x.len(), expected_x.len());
        for (i, (&got, &want)) in result.optimal_x.iter().zip(expected_x.iter()).enumerate() {
            assert!(
                (got - want).abs() < 1e-2,
                "x[{}] = {}, expected {}",
                i,
                got,
                want
            );
        }

        let dl_dx = vec![1.0, 0.0];
        let grad = layer.backward(&result, &dl_dx).expect("Backward failed");
        assert_eq!(grad.dl_dc.len(), 2);
    }

    /// Batch-solves two scalar QPs (Q=2, c=1 and Q=4, c=2), both of which have
    /// the minimiser x* = -c/Q = -0.5.
    #[test]
    fn test_layer_batch_interface() {
        let qp1 = DifferentiableQP::new(vec![vec![2.0]], vec![1.0], vec![], vec![], vec![], vec![])
            .expect("QP1 creation failed");
        let qp2 = DifferentiableQP::new(vec![vec![4.0]], vec![2.0], vec![], vec![], vec![], vec![])
            .expect("QP2 creation failed");
        let config = DiffQPConfig::default();
        let results =
            StandardOptNetLayer::forward_batch(&[qp1, qp2], &config).expect("Batch failed");
        assert_eq!(results.len(), 2);
        assert!(
            (results[0].optimal_x[0] - (-0.5)).abs() < 1e-3,
            "batch[0].x = {}",
            results[0].optimal_x[0]
        );
        assert!(
            (results[1].optimal_x[0] - (-0.5)).abs() < 1e-2,
            "batch[1].x = {}",
            results[1].optimal_x[0]
        );
    }
}