use scirs2_core::ndarray::{Array1, Array2};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
#[cfg(feature = "scirs")]
use crate::scirs_stub::{
scirs2_linalg::norm::Norm,
scirs2_optimization::{OptimizationProblem, Optimizer},
};
/// Tuning parameters for [`PenaltyOptimizer`] and [`analyze_penalty_landscape`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PenaltyConfig {
/// Weight assigned to each constraint by `PenaltyOptimizer::initialize_weights`;
/// also the weight at which `calculate_sensitivity` takes its finite difference.
pub initial_weight: f64,
/// Floor applied when `PenaltyOptimizer` lowers a weight; also the start of the
/// weight range sampled by `analyze_penalty_landscape`.
pub min_weight: f64,
/// Cap applied when `PenaltyOptimizer` raises a weight; also the end of the
/// weight range sampled by `analyze_penalty_landscape`.
pub max_weight: f64,
/// Multiplier applied to a weight whose constraint's |violation| exceeds
/// `violation_tolerance`; a weight is divided by the square root of this factor
/// when its |violation| falls below a tenth of the tolerance.
pub adjustment_factor: f64,
/// Threshold below which a (mean) violation counts as satisfied; drives
/// convergence and the satisfaction rate.
pub violation_tolerance: f64,
/// Maximum number of weight-update rounds performed by `optimize_penalties`.
pub max_iterations: usize,
/// With the `scirs` feature enabled, selects the optimizer-driven weight update
/// (when an optimizer exists) instead of the multiplicative rule.
pub adaptive_scaling: bool,
/// Penalty function applied to constraint violations.
pub penalty_type: PenaltyType,
}
/// Shape of the penalty applied to a violation `v` (before multiplying by the weight).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum PenaltyType {
/// `v²`.
Quadratic,
/// `|v|`.
Linear,
/// `-ln(v)` for `v > 0`; non-positive `v` maps to a very large (or infinite) penalty.
LogBarrier,
/// `eᵛ - 1`, computed with `exp_m1`.
Exponential,
/// `v² + |v|`, computed with a fused multiply-add.
AugmentedLagrangian,
}
impl Default for PenaltyConfig {
    /// Quadratic penalties with weights starting at `1.0` and kept within
    /// `[0.001, 1000.0]`, an adjustment factor of `2.0`, a violation tolerance of
    /// `1e-6`, at most 100 iterations and adaptive scaling enabled.
    fn default() -> Self {
        let (min_weight, max_weight) = (1e-3, 1e3);
        let penalty_type = PenaltyType::Quadratic;
        Self {
            penalty_type,
            adaptive_scaling: true,
            max_iterations: 100,
            violation_tolerance: 1e-6,
            adjustment_factor: 2.0,
            initial_weight: 1.0,
            min_weight,
            max_weight,
        }
    }
}
/// Adaptive tuner for per-constraint penalty weights.
///
/// Holds one weight per constraint name plus a history of the violations
/// recorded by `optimize_penalties`.
pub struct PenaltyOptimizer {
/// Tuning parameters.
config: PenaltyConfig,
/// Current penalty weight keyed by constraint name.
constraint_weights: HashMap<String, f64>,
/// Violations recorded by `optimize_penalties`; appended to on every call, never cleared.
violation_history: Vec<ConstraintViolation>,
/// Optimizer used for weight updates when the `scirs` feature is enabled;
/// created by `initialize_weights`.
#[cfg(feature = "scirs")]
optimizer: Option<Box<dyn Optimizer>>,
}
/// One history entry: a constraint's violation at one optimization iteration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConstraintViolation {
/// Name of the constraint.
pub constraint_name: String,
/// Mean violation of the constraint over the sample results.
pub violation_amount: f64,
/// Weight of the constraint as stored after that iteration's weight update.
pub penalty_weight: f64,
/// Zero-based iteration index within the `optimize_penalties` call that recorded it.
pub iteration: usize,
}
/// Outcome of `PenaltyOptimizer::optimize_penalties`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PenaltyOptimizationResult {
/// Constraint weights held by the optimizer when the run finished.
pub optimal_weights: HashMap<String, f64>,
/// Mean violation per constraint, re-evaluated after the last iteration.
pub final_violations: HashMap<String, f64>,
/// `true` if every |violation| dropped below `violation_tolerance` before
/// `max_iterations` was reached.
pub converged: bool,
/// Number of weight-update rounds that were performed.
pub iterations: usize,
/// Mean penalized objective (sample energy plus weighted violations) over the samples.
pub objective_value: f64,
/// Fraction of constraints whose final |violation| is below `violation_tolerance`.
pub constraint_satisfaction: f64,
}
impl PenaltyOptimizer {
    /// Creates an optimizer with `config` and no constraint weights yet.
    pub fn new(config: PenaltyConfig) -> Self {
        Self {
            config,
            constraint_weights: HashMap::new(),
            violation_history: Vec::new(),
            #[cfg(feature = "scirs")]
            optimizer: None,
        }
    }

    /// Sets the weight of every name in `constraints` to `config.initial_weight`,
    /// overwriting any weight already stored for that name.
    ///
    /// With the `scirs` feature enabled this also (re)creates the L-BFGS optimizer,
    /// sized to `constraints.len()`.
    pub fn initialize_weights(&mut self, constraints: &[String]) {
        for constraint in constraints {
            self.constraint_weights
                .insert(constraint.clone(), self.config.initial_weight);
        }
        #[cfg(feature = "scirs")]
        {
            use crate::scirs_stub::scirs2_optimization::gradient::LBFGS;
            self.optimizer = Some(Box::new(LBFGS::new(constraints.len())));
        }
    }

    /// Adjusts constraint weights until every mean violation is below
    /// `config.violation_tolerance` or `config.max_iterations` rounds have run.
    ///
    /// Any constraint of `model` that has no weight yet is first seeded with
    /// `config.initial_weight`. This means the run never panics on an unknown
    /// constraint, and every constraint takes part in the adjustment. Each
    /// non-converged round updates the weights and then appends one
    /// [`ConstraintViolation`] per constraint to the history, recording the
    /// weight *after* the update.
    ///
    /// # Errors
    /// Propagates any error raised while evaluating constraints or updating weights.
    pub fn optimize_penalties(
        &mut self,
        model: &CompiledModel,
        sample_results: &[(Vec<bool>, f64)],
    ) -> Result<PenaltyOptimizationResult, Box<dyn std::error::Error>> {
        let initial_weight = self.config.initial_weight;
        for name in model.get_constraints().keys() {
            self.constraint_weights
                .entry(name.clone())
                .or_insert(initial_weight);
        }

        let mut iteration = 0;
        let mut converged = false;
        while iteration < self.config.max_iterations {
            let violations = self.evaluate_violations(model, sample_results)?;
            let max_violation = violations.values().map(|v| v.abs()).fold(0.0, f64::max);
            if max_violation < self.config.violation_tolerance {
                converged = true;
                break;
            }
            self.update_weights(&violations, iteration)?;
            for (name, &violation) in &violations {
                let penalty_weight = self
                    .constraint_weights
                    .get(name)
                    .copied()
                    .unwrap_or(initial_weight);
                self.violation_history.push(ConstraintViolation {
                    constraint_name: name.clone(),
                    violation_amount: violation,
                    penalty_weight,
                    iteration,
                });
            }
            iteration += 1;
        }

        let final_violations = self.evaluate_violations(model, sample_results)?;
        let objective_value = self.calculate_objective(model, sample_results)?;
        let constraint_satisfaction = self.calculate_satisfaction_rate(&final_violations);
        Ok(PenaltyOptimizationResult {
            optimal_weights: self.constraint_weights.clone(),
            final_violations,
            converged,
            iterations: iteration,
            objective_value,
            constraint_satisfaction,
        })
    }

    /// Returns the mean violation of each model constraint over `sample_results`.
    ///
    /// An empty `sample_results` yields `0.0` for every constraint.
    fn evaluate_violations(
        &self,
        model: &CompiledModel,
        sample_results: &[(Vec<bool>, f64)],
    ) -> Result<HashMap<String, f64>, Box<dyn std::error::Error>> {
        let constraints = model.get_constraints();
        let mut violations = HashMap::with_capacity(constraints.len());
        for (constraint_name, constraint_expr) in constraints {
            let mut total_violation = 0.0;
            for (assignment, _energy) in sample_results {
                total_violation += self.evaluate_constraint_violation(
                    constraint_expr,
                    assignment,
                    model.get_variable_map(),
                )?;
            }
            let mean_violation = if sample_results.is_empty() {
                0.0
            } else {
                total_violation / sample_results.len() as f64
            };
            violations.insert(constraint_name.clone(), mean_violation);
        }
        Ok(violations)
    }

    /// Applies the configured penalty shape to a constraint's value under `_assignment`.
    ///
    /// NOTE(review): placeholder. The constraint is never evaluated and `value` is
    /// fixed at `0.0`. As a result, Quadratic, Linear, Exponential and
    /// AugmentedLagrangian always yield `0.0`, and LogBarrier always yields
    /// `f64::INFINITY`.
    fn evaluate_constraint_violation(
        &self,
        _constraint: &ConstraintExpr,
        _assignment: &[bool],
        _var_map: &HashMap<String, usize>,
    ) -> Result<f64, Box<dyn std::error::Error>> {
        let value: f64 = 0.0;
        Ok(match self.config.penalty_type {
            PenaltyType::Quadratic => value.powi(2),
            PenaltyType::Linear => value.abs(),
            PenaltyType::LogBarrier => {
                if value > 0.0 {
                    -value.ln()
                } else {
                    f64::INFINITY
                }
            }
            PenaltyType::Exponential => value.exp_m1(),
            PenaltyType::AugmentedLagrangian => value.mul_add(value, value.abs()),
        })
    }

    /// Updates the stored weights from this round's `violations`.
    ///
    /// With the `scirs` feature, `adaptive_scaling` set and an optimizer created,
    /// the update is delegated to `update_weights_optimized`. Otherwise the
    /// multiplicative rule applies:
    /// - |violation| > tolerance: the weight is multiplied by `adjustment_factor`
    ///   and capped at `max_weight`.
    /// - |violation| < tolerance / 10: the weight is divided by
    ///   `sqrt(adjustment_factor)` and floored at `min_weight`.
    ///
    /// Constraints without a stored weight are skipped.
    fn update_weights(
        &mut self,
        violations: &HashMap<String, f64>,
        iteration: usize,
    ) -> Result<(), Box<dyn std::error::Error>> {
        #[cfg(feature = "scirs")]
        {
            if self.config.adaptive_scaling && self.optimizer.is_some() {
                self.update_weights_optimized(violations, iteration)?;
                return Ok(());
            }
        }
        // `iteration` is only consumed by the scirs-backed path.
        #[cfg(not(feature = "scirs"))]
        let _ = iteration;

        let tolerance = self.config.violation_tolerance;
        for (constraint_name, &violation) in violations {
            if let Some(weight) = self.constraint_weights.get_mut(constraint_name) {
                let magnitude = violation.abs();
                if magnitude > tolerance {
                    *weight = (*weight * self.config.adjustment_factor).min(self.config.max_weight);
                } else if magnitude < tolerance * 0.1 {
                    *weight = (*weight / self.config.adjustment_factor.sqrt())
                        .max(self.config.min_weight);
                }
            }
        }
        Ok(())
    }

    /// Replaces the weights of the constraints in `violations` with the minimizer
    /// of `WeightOptimizationObjective`, bounded to `[min_weight, max_weight]`.
    /// Constraints without a stored weight start from `config.initial_weight`.
    #[cfg(feature = "scirs")]
    fn update_weights_optimized(
        &mut self,
        violations: &HashMap<String, f64>,
        iteration: usize,
    ) -> Result<(), Box<dyn std::error::Error>> {
        use crate::scirs_stub::scirs2_optimization::{Bounds, ObjectiveFunction};
        let constraint_names: Vec<_> = violations.keys().cloned().collect();
        let initial_weight = self.config.initial_weight;
        let current_weights: Array1<f64> = constraint_names
            .iter()
            .map(|name| {
                self.constraint_weights
                    .get(name)
                    .copied()
                    .unwrap_or(initial_weight)
            })
            .collect();
        let violations_vec: Array1<f64> = constraint_names
            .iter()
            .map(|name| violations[name].abs())
            .collect();
        let objective = WeightOptimizationObjective {
            violations: violations_vec,
            penalty_type: self.config.penalty_type,
            regularization: 0.01,
        };
        let lower_bounds = Array1::from_elem(constraint_names.len(), self.config.min_weight);
        let upper_bounds = Array1::from_elem(constraint_names.len(), self.config.max_weight);
        let bounds = Bounds::new(lower_bounds, upper_bounds);
        if let Some(ref mut optimizer) = self.optimizer {
            let result = optimizer.minimize(&objective, &current_weights, &bounds, iteration)?;
            for (i, name) in constraint_names.iter().enumerate() {
                self.constraint_weights.insert(name.clone(), result.x[i]);
            }
        }
        Ok(())
    }

    /// Mean over `sample_results` of the sample energy plus `Σ weight · violation`.
    ///
    /// Returns `0.0` for an empty `sample_results`, matching `evaluate_violations`.
    /// A constraint with no stored weight uses `config.initial_weight`.
    fn calculate_objective(
        &self,
        model: &CompiledModel,
        sample_results: &[(Vec<bool>, f64)],
    ) -> Result<f64, Box<dyn std::error::Error>> {
        if sample_results.is_empty() {
            return Ok(0.0);
        }
        let mut total_objective = 0.0;
        for (assignment, energy) in sample_results {
            let mut penalized_objective = *energy;
            for (constraint_name, constraint_expr) in model.get_constraints() {
                let violation = self.evaluate_constraint_violation(
                    constraint_expr,
                    assignment,
                    model.get_variable_map(),
                )?;
                let weight = self
                    .constraint_weights
                    .get(constraint_name)
                    .copied()
                    .unwrap_or(self.config.initial_weight);
                penalized_objective += weight * violation;
            }
            total_objective += penalized_objective;
        }
        Ok(total_objective / sample_results.len() as f64)
    }

    /// Fraction of constraints whose |violation| is below `violation_tolerance`.
    /// Returns `1.0` when there are no constraints.
    fn calculate_satisfaction_rate(&self, violations: &HashMap<String, f64>) -> f64 {
        if violations.is_empty() {
            return 1.0;
        }
        let satisfied = violations
            .values()
            .filter(|v| v.abs() < self.config.violation_tolerance)
            .count();
        satisfied as f64 / violations.len() as f64
    }

    /// Returns the current weight of `constraint_name`, if one is stored.
    pub fn get_weight(&self, constraint_name: &str) -> Option<f64> {
        self.constraint_weights.get(constraint_name).copied()
    }

    /// Returns every violation recorded so far, in insertion order.
    pub fn get_violation_history(&self) -> &[ConstraintViolation] {
        &self.violation_history
    }

    /// Snapshots the configuration, current weights, and the violations recorded
    /// at the highest iteration index found in the history.
    ///
    /// The history is never cleared, so after several calls to
    /// `optimize_penalties` the "highest iteration" may come from any of those runs.
    pub fn export_config(&self) -> PenaltyExport {
        // Computed once; previously this was recomputed for every history entry (O(n²)).
        let last_iteration = self
            .violation_history
            .iter()
            .map(|v| v.iteration)
            .max()
            .unwrap_or(0);
        let final_violations = self
            .violation_history
            .iter()
            .filter(|v| v.iteration == last_iteration)
            .map(|v| (v.constraint_name.clone(), v.violation_amount))
            .collect();
        PenaltyExport {
            config: self.config.clone(),
            weights: self.constraint_weights.clone(),
            final_violations,
        }
    }
}
/// Serializable snapshot produced by `PenaltyOptimizer::export_config`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PenaltyExport {
/// Configuration the optimizer was built with.
pub config: PenaltyConfig,
/// Current weight per constraint name.
pub weights: HashMap<String, f64>,
/// Violation per constraint taken from the history entries at the highest iteration index.
pub final_violations: HashMap<String, f64>,
}
/// Objective minimized by the scirs-backed weight update: weighted violation
/// sum plus an L2 regularization term on the weights.
#[cfg(feature = "scirs")]
struct WeightOptimizationObjective {
/// |violation| per constraint, in the same order as the weight vector.
violations: Array1<f64>,
/// NOTE(review): stored but not read by the `ObjectiveFunction` impl in this file.
penalty_type: PenaltyType,
/// Coefficient of the `‖w‖²` regularization term.
regularization: f64,
}
#[cfg(feature = "scirs")]
impl crate::scirs_stub::scirs2_optimization::ObjectiveFunction for WeightOptimizationObjective {
    /// `Σ wᵢ·vᵢ + λ·‖w‖²`, where `v` are the stored violations and `λ` the regularization.
    fn evaluate(&self, weights: &Array1<f64>) -> f64 {
        let weighted_sum = (weights * &self.violations).sum();
        let ridge = self.regularization * weights.dot(weights);
        weighted_sum + ridge
    }

    /// Analytic gradient of `evaluate`: `v + 2λ·w`.
    fn gradient(&self, weights: &Array1<f64>) -> Array1<f64> {
        let scale = 2.0 * self.regularization;
        &self.violations + scale * weights
    }
}
/// Minimal model representation consumed by [`PenaltyOptimizer`].
#[derive(Debug, Clone)]
pub struct CompiledModel {
/// Constraint expressions keyed by constraint name.
constraints: HashMap<String, ConstraintExpr>,
/// Variable name to index; presumably an index into a `&[bool]` assignment — TODO confirm.
variable_map: HashMap<String, usize>,
}
impl Default for CompiledModel {
fn default() -> Self {
Self::new()
}
}
impl CompiledModel {
    /// Creates a model with no constraints and no variables.
    pub fn new() -> Self {
        let variable_map = HashMap::new();
        let constraints = HashMap::new();
        Self {
            constraints,
            variable_map,
        }
    }

    /// Borrows the constraint expressions, keyed by constraint name.
    pub const fn get_constraints(&self) -> &HashMap<String, ConstraintExpr> {
        &self.constraints
    }

    /// Borrows the variable-name → index map.
    pub const fn get_variable_map(&self) -> &HashMap<String, usize> {
        &self.variable_map
    }

    /// Returns an all-zero `n × n` matrix, where `n` is the number of variables,
    /// together with a copy of the variable map.
    ///
    /// NOTE(review): no QUBO terms are filled in; the matrix is always zero.
    pub fn to_qubo(&self) -> (Array2<f64>, HashMap<String, usize>) {
        let n = self.variable_map.len();
        let matrix = Array2::zeros((n, n));
        (matrix, self.variable_map.clone())
    }
}
/// A constraint in textual form.
///
/// NOTE(review): `expression` is not parsed or evaluated anywhere in this file.
#[derive(Debug, Clone)]
pub struct ConstraintExpr {
/// Source text of the constraint.
pub expression: String,
}
/// Evaluates a term numerically under a boolean variable assignment.
///
/// NOTE(review): this private trait has no implementations or callers in this file.
trait TermEvaluator {
/// Returns the term's value for `assignment`, with `var_map` resolving variable names.
fn evaluate_with_assignment(
&self,
assignment: &[bool],
var_map: &HashMap<String, usize>,
) -> Result<f64, Box<dyn std::error::Error>>;
}
pub fn analyze_penalty_landscape(config: &PenaltyConfig, violations: &[f64]) -> PenaltyAnalysis {
let weights = Array1::linspace(config.min_weight, config.max_weight, 100);
let mut penalties = Vec::new();
for &weight in &weights {
let penalty_values: Vec<f64> = violations
.iter()
.map(|&v| calculate_penalty(v, weight, config.penalty_type))
.collect();
penalties.push(PenaltyPoint {
weight,
avg_penalty: penalty_values.iter().sum::<f64>() / penalty_values.len() as f64,
max_penalty: penalty_values.iter().fold(0.0, |a, &b| a.max(b)),
min_penalty: penalty_values.iter().fold(f64::INFINITY, |a, &b| a.min(b)),
});
}
PenaltyAnalysis {
penalty_points: penalties,
optimal_weight: find_optimal_weight(&weights, violations, config),
sensitivity: calculate_sensitivity(violations, config),
}
}
/// Returns `weight` times the base penalty of `violation` for `penalty_type`.
///
/// Base penalties:
/// - Quadratic: `v²`.
/// - Linear: `|v|`.
/// - LogBarrier: `-ln(v)` for `v > 0`, otherwise the large finite constant `1e10`.
/// - Exponential: `eᵛ - 1`.
/// - AugmentedLagrangian: `v² + |v|`, evaluated with a fused multiply-add.
fn calculate_penalty(violation: f64, weight: f64, penalty_type: PenaltyType) -> f64 {
    let base = match penalty_type {
        PenaltyType::Linear => violation.abs(),
        PenaltyType::Quadratic => violation.powi(2),
        PenaltyType::Exponential => violation.exp_m1(),
        PenaltyType::AugmentedLagrangian => violation.mul_add(violation, violation.abs()),
        PenaltyType::LogBarrier if violation > 0.0 => -violation.ln(),
        PenaltyType::LogBarrier => 1e10,
    };
    weight * base
}
/// Picks the weight in `weights` whose total penalty over `violations` is closest
/// to `violations.len() * config.violation_tolerance`.
///
/// Ties go to the earliest weight. If `weights` is empty, or no weight yields a
/// comparable (non-NaN) distance, `config.initial_weight` is returned.
fn find_optimal_weight(weights: &Array1<f64>, violations: &[f64], config: &PenaltyConfig) -> f64 {
    let target = config.violation_tolerance * violations.len() as f64;
    let total_at = |weight: f64| -> f64 {
        violations
            .iter()
            .map(|&v| calculate_penalty(v, weight, config.penalty_type))
            .sum()
    };
    let (best_weight, _best_distance) = weights.iter().fold(
        (config.initial_weight, f64::INFINITY),
        |(kept_weight, kept_distance), &candidate| {
            let distance = (total_at(candidate) - target).abs();
            if distance < kept_distance {
                (candidate, distance)
            } else {
                (kept_weight, kept_distance)
            }
        },
    );
    best_weight
}
/// Estimates how strongly the penalty reacts to the weight at `config.initial_weight`.
///
/// For every violation this takes a forward finite difference of
/// `calculate_penalty` with respect to the weight, using a step of 1% of the
/// weight, and returns the mean of those differences. If that step collapses to
/// zero (a zero or subnormal weight), an absolute step of `0.01` is used instead,
/// so the result stays finite rather than `0/0 = NaN`.
///
/// Returns `0.0` when `violations` is empty.
fn calculate_sensitivity(violations: &[f64], config: &PenaltyConfig) -> f64 {
    if violations.is_empty() {
        return 0.0;
    }
    let weight = config.initial_weight;
    let relative_step = 0.01 * weight;
    let delta = if relative_step == 0.0 { 0.01 } else { relative_step };
    let total: f64 = violations
        .iter()
        .map(|&v| {
            let base = calculate_penalty(v, weight, config.penalty_type);
            let bumped = calculate_penalty(v, weight + delta, config.penalty_type);
            (bumped - base) / delta
        })
        .sum();
    total / violations.len() as f64
}
/// Result of [`analyze_penalty_landscape`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PenaltyAnalysis {
/// One summary per sampled weight, in increasing weight order.
pub penalty_points: Vec<PenaltyPoint>,
/// Sampled weight whose total penalty is closest to `violations.len() * violation_tolerance`.
pub optimal_weight: f64,
/// Mean finite-difference derivative of the penalty with respect to the weight,
/// taken at `initial_weight`.
pub sensitivity: f64,
}
/// Penalty statistics over all analysed violations at one sampled weight.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PenaltyPoint {
/// The sampled penalty weight.
pub weight: f64,
/// Mean penalty at `weight`.
pub avg_penalty: f64,
/// Largest penalty at `weight`.
pub max_penalty: f64,
/// Smallest penalty at `weight`.
pub min_penalty: f64,
}