use crate::constraint::{PenaltyFunction, ViolationComputable};
use crate::error::LogicResult;
/// Differentiable (soft) projection utilities based on a softplus
/// relaxation of hard clamping, controlled by a temperature parameter.
pub struct DifferentiableProjection {
    /// Softness parameter; smaller values approach a hard clamp.
    temperature: f32,
    /// Iteration budget reserved for iterative projection schemes
    /// (not used by the closed-form projections below).
    max_iterations: usize,
}

impl DifferentiableProjection {
    /// Creates a projection operator with the given temperature.
    ///
    /// # Panics
    /// Panics if `temperature` is not strictly positive.
    pub fn new(temperature: f32) -> Self {
        assert!(temperature > 0.0, "Temperature must be positive");
        Self {
            temperature,
            max_iterations: 10,
        }
    }

    /// Builder-style override of the iteration budget.
    pub fn with_max_iterations(mut self, max_iter: usize) -> Self {
        self.max_iterations = max_iter;
        self
    }

    /// Numerically stable softplus: `ln(1 + e^t)`.
    ///
    /// Computing `t.exp().ln_1p()` directly overflows `exp` for large
    /// positive `t`; the `max(0) + ln_1p(exp(-|t|))` form never overflows.
    fn softplus(t: f32) -> f32 {
        t.max(0.0) + (-t.abs()).exp().ln_1p()
    }

    /// Softly projects `x` into the box `[lower, upper]`.
    ///
    /// Applies a softplus-smoothed `max(x, lower)` followed by a
    /// softplus-smoothed `min(·, upper)`: inputs below `lower` are pulled
    /// up toward `lower`, inputs above `upper` are pulled down toward
    /// `upper`, and interior points are left nearly unchanged. As
    /// `temperature → 0` this approaches `x.clamp(lower, upper)`.
    ///
    /// Fix over the previous version: taking `min(soft_lower, soft_upper)`
    /// returned the *unprojected* `x` for inputs below `lower` (there
    /// `soft_lower ≈ lower` but `soft_upper ≈ x`, and the min picked `x`),
    /// so the lower bound was never enforced. Composing the two soft
    /// clamps enforces both bounds.
    pub fn soft_project_box(&self, x: f32, lower: f32, upper: f32) -> f32 {
        let tau = self.temperature;
        // Soft maximum with the lower bound.
        let y = x + tau * Self::softplus((lower - x) / tau);
        // Soft minimum with the upper bound, applied to the result.
        y - tau * Self::softplus((y - upper) / tau)
    }

    /// Single step of a barrier-style projection.
    ///
    /// For `constraint_val >= 0` this takes a plain penalty-gradient step;
    /// otherwise it steps using a `-log`-barrier-style gradient whose
    /// magnitude grows as `constraint_val` approaches zero from below.
    /// NOTE(review): assumes the convention "negative == strictly
    /// feasible" for `constraint_val` — confirm against the constraint API.
    pub fn barrier_project(&self, value: f32, gradient: f32, constraint_val: f32) -> f32 {
        if constraint_val >= 0.0 {
            value - self.temperature * gradient * constraint_val
        } else {
            let barrier_grad = -self.temperature / constraint_val;
            value - self.temperature * gradient * barrier_grad
        }
    }

    /// Current temperature.
    pub fn temperature(&self) -> f32 {
        self.temperature
    }
}
/// Combines a task loss with a weighted penalty on constraint violations.
pub struct ConstraintAwareLoss<C> {
    /// Constraints whose violations are summed into the penalty.
    constraints: Vec<C>,
    /// Shape of the penalty applied to the summed violation.
    penalty_function: PenaltyFunction,
    /// Weight passed to the penalty function.
    constraint_weight: f32,
}

impl<C: ViolationComputable> ConstraintAwareLoss<C> {
    /// Creates a loss from constraints, a penalty shape, and a weight.
    pub fn new(constraints: Vec<C>, penalty: PenaltyFunction, weight: f32) -> Self {
        Self {
            constraints,
            penalty_function: penalty,
            constraint_weight: weight,
        }
    }

    /// Total loss: `task_loss` plus the constraint penalty for `prediction`.
    pub fn compute_loss(&self, prediction: &[f32], task_loss: f32) -> f32 {
        // Delegate so the violation sum and penalty are computed in exactly
        // one place (previously this duplicated `constraint_penalty`).
        task_loss + self.constraint_penalty(prediction)
    }

    /// Penalty term alone: the penalty function applied to the summed
    /// violation across all constraints.
    pub fn constraint_penalty(&self, prediction: &[f32]) -> f32 {
        let violation: f32 = self
            .constraints
            .iter()
            .map(|c| c.violation(prediction))
            .sum();
        self.penalty_function
            .compute(violation, self.constraint_weight)
    }

    /// True when every constraint's `check` passes for `prediction`.
    pub fn all_satisfied(&self, prediction: &[f32]) -> bool {
        self.constraints.iter().all(|c| c.check(prediction))
    }

    /// Current penalty weight.
    pub fn constraint_weight(&self) -> f32 {
        self.constraint_weight
    }

    /// Replaces the penalty weight (e.g. for annealing schedules).
    pub fn set_constraint_weight(&mut self, weight: f32) {
        self.constraint_weight = weight;
    }
}
/// Lagrangian relaxation: one nonnegative multiplier per constraint,
/// updated by projected gradient ascent on the dual.
pub struct LagrangianRelaxation {
    /// Dual variables, one per constraint, kept in `[0, max_multiplier]`.
    multipliers: Vec<f32>,
    /// Step size for the dual (multiplier) update.
    learning_rate_multipliers: f32,
    /// Upper cap applied when clamping each multiplier.
    max_multiplier: f32,
}

impl LagrangianRelaxation {
    /// Starts with zero-initialized multipliers for `num_constraints`
    /// constraints, a 0.01 dual step size, and a cap of 100.
    pub fn new(num_constraints: usize) -> Self {
        Self {
            multipliers: vec![0.0; num_constraints],
            learning_rate_multipliers: 0.01,
            max_multiplier: 100.0,
        }
    }

    /// Builder-style override of the dual step size.
    pub fn with_multiplier_lr(mut self, lr: f32) -> Self {
        self.learning_rate_multipliers = lr;
        self
    }

    /// Builder-style override of the multiplier cap.
    pub fn with_max_multiplier(mut self, max_val: f32) -> Self {
        self.max_multiplier = max_val;
        self
    }

    /// Lagrangian value: `task_loss + Σ λ_i · violation_i(prediction)`.
    /// Multipliers and constraints are paired positionally; extras on
    /// either side are ignored by the zip.
    pub fn compute_lagrangian<C: ViolationComputable>(
        &self,
        prediction: &[f32],
        constraints: &[C],
        task_loss: f32,
    ) -> f32 {
        let mut weighted_violations = 0.0;
        for (mult, constraint) in self.multipliers.iter().zip(constraints) {
            weighted_violations += mult * constraint.violation(prediction);
        }
        task_loss + weighted_violations
    }

    /// Dual ascent step: `λ_i ← clamp(λ_i + lr · violation_i, 0, cap)`.
    pub fn update_multipliers<C: ViolationComputable>(
        &mut self,
        prediction: &[f32],
        constraints: &[C],
    ) {
        let lr = self.learning_rate_multipliers;
        let cap = self.max_multiplier;
        for (mult, constraint) in self.multipliers.iter_mut().zip(constraints) {
            let step = lr * constraint.violation(prediction);
            *mult = (*mult + step).clamp(0.0, cap);
        }
    }

    /// Read-only view of the current multipliers.
    pub fn multipliers(&self) -> &[f32] {
        &self.multipliers
    }

    /// Resets every multiplier to zero.
    pub fn reset_multipliers(&mut self) {
        self.multipliers.fill(0.0);
    }

    /// Mean multiplier value, or 0.0 when there are no constraints.
    pub fn average_multiplier(&self) -> f32 {
        match self.multipliers.len() {
            0 => 0.0,
            n => self.multipliers.iter().sum::<f32>() / n as f32,
        }
    }
}
/// Classic penalty method: a weighted penalty on the total violation,
/// with a weight that is multiplicatively increased between outer rounds.
pub struct PenaltyMethod {
    /// Current penalty weight.
    penalty_weight: f32,
    /// Multiplicative factor applied by `increase_penalty`.
    penalty_increase_factor: f32,
    /// Shape of the penalty applied to the summed violation.
    penalty_function: PenaltyFunction,
}

impl PenaltyMethod {
    /// Starts with weight 1.0 and a 10x increase factor.
    pub fn new(penalty_function: PenaltyFunction) -> Self {
        Self {
            penalty_weight: 1.0,
            penalty_increase_factor: 10.0,
            penalty_function,
        }
    }

    /// Builder-style override of the starting penalty weight.
    pub fn with_penalty_weight(mut self, weight: f32) -> Self {
        self.penalty_weight = weight;
        self
    }

    /// Builder-style override of the increase factor.
    pub fn with_increase_factor(mut self, factor: f32) -> Self {
        self.penalty_increase_factor = factor;
        self
    }

    /// Penalized loss: `task_loss + penalty(Σ violation_i, weight)`.
    pub fn compute_loss<C: ViolationComputable>(
        &self,
        prediction: &[f32],
        constraints: &[C],
        task_loss: f32,
    ) -> f32 {
        let mut violation_sum = 0.0;
        for constraint in constraints {
            violation_sum += constraint.violation(prediction);
        }
        let penalty = self
            .penalty_function
            .compute(violation_sum, self.penalty_weight);
        task_loss + penalty
    }

    /// Multiplies the weight by the increase factor (outer-loop schedule).
    pub fn increase_penalty(&mut self) {
        self.penalty_weight *= self.penalty_increase_factor;
    }

    /// Current penalty weight.
    pub fn penalty_weight(&self) -> f32 {
        self.penalty_weight
    }

    /// Resets the weight to `initial_weight` for a fresh run.
    pub fn reset_penalty(&mut self, initial_weight: f32) {
        self.penalty_weight = initial_weight;
    }
}
/// Interior-point style log-barrier method: loss grows without bound as
/// constraints approach infeasibility, with a barrier weight decreased
/// between outer rounds.
pub struct BarrierMethod {
    /// Current barrier weight (μ in the usual formulation).
    barrier_weight: f32,
    /// Multiplicative factor applied by `decrease_barrier`.
    barrier_decrease_factor: f32,
}

impl BarrierMethod {
    /// Starts with barrier weight 1.0 and a 0.1 decrease factor.
    pub fn new() -> Self {
        Self {
            barrier_weight: 1.0,
            barrier_decrease_factor: 0.1,
        }
    }

    /// Builder-style override of the barrier weight.
    pub fn with_barrier_weight(mut self, weight: f32) -> Self {
        self.barrier_weight = weight;
        self
    }

    /// Builder-style override of the decrease factor.
    pub fn with_decrease_factor(mut self, factor: f32) -> Self {
        self.barrier_decrease_factor = factor;
        self
    }

    /// Barrier loss: `task_loss + μ · Σ -ln(-g_i)` with
    /// `g_i = -violation_i - ε`; returns `f32::MAX` whenever any
    /// constraint is infeasible (violation > 0) or `g_i` is nonnegative.
    ///
    /// Fix over the previous version: it accumulated `g_x.ln()`, but at
    /// that point `g_x` is guaranteed negative (the `g_x >= 0.0` guard
    /// returned early), so `ln` produced NaN on every feasible
    /// evaluation. The log barrier is `-ln(-g)`, i.e. `(-g_x).ln()`
    /// subtracted from the accumulator.
    ///
    /// NOTE(review): if `violation` is hinge-style (0 whenever satisfied,
    /// as the file's tests suggest), the barrier degenerates to a constant
    /// `-ln(ε)` per satisfied constraint; a signed margin would be needed
    /// for a true interior barrier — confirm against the constraint API.
    pub fn compute_loss<C: ViolationComputable>(
        &self,
        prediction: &[f32],
        constraints: &[C],
        task_loss: f32,
    ) -> LogicResult<f32> {
        let mut barrier_term = 0.0;
        for constraint in constraints {
            let violation = constraint.violation(prediction);
            // Outside the feasible region the log barrier is undefined;
            // signal an effectively infinite loss instead.
            if violation > 0.0 {
                return Ok(f32::MAX);
            }
            // Keep the argument of `ln` bounded away from zero.
            let epsilon = 1e-6;
            let g_x = -violation - epsilon;
            if g_x >= 0.0 {
                return Ok(f32::MAX);
            }
            // Standard log barrier contribution: -ln(-g(x)), with g(x) < 0.
            barrier_term -= (-g_x).ln();
        }
        Ok(task_loss + self.barrier_weight * barrier_term)
    }

    /// Multiplies the weight by the decrease factor (outer-loop schedule).
    pub fn decrease_barrier(&mut self) {
        self.barrier_weight *= self.barrier_decrease_factor;
    }

    /// Current barrier weight.
    pub fn barrier_weight(&self) -> f32 {
        self.barrier_weight
    }

    /// Resets the weight to `initial_weight` for a fresh run.
    pub fn reset_barrier(&mut self, initial_weight: f32) {
        self.barrier_weight = initial_weight;
    }
}

impl Default for BarrierMethod {
    fn default() -> Self {
        Self::new()
    }
}
/// Per-constraint penalty weights that adapt online: weights grow for
/// violated constraints and slowly decay for satisfied ones, clamped to
/// a configurable range.
pub struct AdaptiveWeighting {
    /// One weight per constraint, paired positionally.
    constraint_weights: Vec<f32>,
    /// Step size for weight adjustments.
    learning_rate: f32,
    /// Lower clamp for every weight.
    min_weight: f32,
    /// Upper clamp for every weight.
    max_weight: f32,
}

impl AdaptiveWeighting {
    /// Starts all weights at 1.0 with lr 0.01 and bounds [0.01, 100].
    pub fn new(num_constraints: usize) -> Self {
        Self {
            constraint_weights: vec![1.0; num_constraints],
            learning_rate: 0.01,
            min_weight: 0.01,
            max_weight: 100.0,
        }
    }

    /// Builder-style override of the learning rate.
    pub fn with_learning_rate(mut self, lr: f32) -> Self {
        self.learning_rate = lr;
        self
    }

    /// Builder-style override of the weight clamp bounds.
    pub fn with_weight_bounds(mut self, min: f32, max: f32) -> Self {
        self.min_weight = min;
        self.max_weight = max;
        self
    }

    /// Adjusts each weight by `+lr·violation` when violated, otherwise by
    /// a small decay of `-0.1·lr`, then clamps to the configured bounds.
    pub fn update_weights<C: ViolationComputable>(
        &mut self,
        prediction: &[f32],
        constraints: &[C],
    ) {
        let lr = self.learning_rate;
        let (lo, hi) = (self.min_weight, self.max_weight);
        for (weight, constraint) in self.constraint_weights.iter_mut().zip(constraints) {
            let v = constraint.violation(prediction);
            // Violated constraints get heavier; satisfied ones decay slowly.
            let delta = if v > 0.0 { lr * v } else { -lr * 0.1 };
            *weight = (*weight + delta).clamp(lo, hi);
        }
    }

    /// Read-only view of the current weights.
    pub fn weights(&self) -> &[f32] {
        &self.constraint_weights
    }

    /// Sum of per-constraint penalties, each computed with its own weight.
    pub fn weighted_penalty<C: ViolationComputable>(
        &self,
        prediction: &[f32],
        constraints: &[C],
        penalty: PenaltyFunction,
    ) -> f32 {
        let mut total = 0.0;
        for (&weight, constraint) in self.constraint_weights.iter().zip(constraints) {
            let violation = constraint.violation(prediction);
            total += penalty.compute(violation, weight);
        }
        total
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::constraint::ConstraintBuilder;

    /// Out-of-box inputs should be pulled toward the box and interior
    /// inputs left roughly in place. Tolerances are loose because the
    /// projection is temperature-smoothed rather than a hard clamp.
    #[test]
    fn test_differentiable_projection() {
        let proj = DifferentiableProjection::new(0.1);
        // Below the box [0, 1].
        let result = proj.soft_project_box(-0.5, 0.0, 1.0);
        assert!(result < 0.2);
        // Above the box: expect a value near the upper edge.
        let result = proj.soft_project_box(1.5, 0.0, 1.0);
        assert!(result > 0.8);
        // Interior point: should stay close to its original value.
        let result = proj.soft_project_box(0.5, 0.0, 1.0);
        assert!((result - 0.5).abs() < 0.2);
    }

    /// With an L2 penalty and `less_eq(10.0)`, a prediction of 15 incurs
    /// violation 5 and penalty 5^2 = 25 (presumably hinge-style violations
    /// — consistent with the expected values asserted below).
    #[test]
    fn test_constraint_aware_loss() {
        let constraint = ConstraintBuilder::new()
            .name("max_val")
            .less_eq(10.0)
            .build()
            .unwrap();
        let loss = ConstraintAwareLoss::new(vec![constraint], PenaltyFunction::L2, 1.0);
        // Satisfied constraint: zero penalty.
        assert!(loss.all_satisfied(&[5.0]));
        let penalty = loss.constraint_penalty(&[5.0]);
        assert_eq!(penalty, 0.0);
        // Violated constraint: penalty (15 - 10)^2 = 25.
        assert!(!loss.all_satisfied(&[15.0]));
        let penalty = loss.constraint_penalty(&[15.0]);
        assert!((penalty - 25.0).abs() < 0.01);
    }

    /// Multipliers start at zero (Lagrangian equals the task loss) and
    /// grow after an update driven by a violating prediction.
    #[test]
    fn test_lagrangian_relaxation() {
        let constraint = ConstraintBuilder::new()
            .name("upper")
            .less_eq(5.0)
            .build()
            .unwrap();
        let mut lagrangian = LagrangianRelaxation::new(1).with_multiplier_lr(0.1);
        let task_loss = 1.0;
        // With zero multipliers the constraint terms vanish.
        let lag =
            lagrangian.compute_lagrangian(&[3.0], std::slice::from_ref(&constraint), task_loss);
        assert_eq!(lag, task_loss);
        // A violating prediction (10 > 5) should raise the multiplier,
        // making the Lagrangian exceed the bare task loss.
        lagrangian.update_multipliers(&[10.0], std::slice::from_ref(&constraint));
        let lag = lagrangian.compute_lagrangian(&[10.0], &[constraint], task_loss);
        assert!(lag > task_loss);
    }

    /// Quadratic penalty: task 2 + 1.0 * (15 - 10)^2 = 27; the weight
    /// grows by the default 10x factor on `increase_penalty`.
    #[test]
    fn test_penalty_method() {
        let mut penalty = PenaltyMethod::new(PenaltyFunction::L2).with_penalty_weight(1.0);
        let constraint = ConstraintBuilder::new()
            .name("bound")
            .less_eq(10.0)
            .build()
            .unwrap();
        let task_loss = 2.0;
        let total = penalty.compute_loss(&[15.0], &[constraint], task_loss);
        assert!((total - 27.0).abs() < 0.01);
        penalty.increase_penalty();
        assert_eq!(penalty.penalty_weight(), 10.0);
    }

    /// Weights grow for the violated constraint and shrink (slightly)
    /// for the satisfied one after a single update.
    #[test]
    fn test_adaptive_weighting() {
        // Violated by prediction 10 (less_eq 5).
        let constraint1 = ConstraintBuilder::new()
            .name("c1")
            .less_eq(5.0)
            .build()
            .unwrap();
        // Satisfied by prediction 10 (greater_eq 0).
        let constraint2 = ConstraintBuilder::new()
            .name("c2")
            .greater_eq(0.0)
            .build()
            .unwrap();
        let mut adaptive = AdaptiveWeighting::new(2).with_learning_rate(0.1);
        assert_eq!(adaptive.weights(), &[1.0, 1.0]);
        adaptive.update_weights(&[10.0], &[constraint1.clone(), constraint2.clone()]);
        assert!(adaptive.weights()[0] > 1.0);
        assert!(adaptive.weights()[1] <= 1.0);
    }
}