1use scirs2_core::ndarray::{Array1, Array2};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
10#[cfg(feature = "scirs")]
11use crate::scirs_stub::{
12 scirs2_linalg::norm::Norm,
13 scirs2_optimization::{OptimizationProblem, Optimizer},
14};
15
/// Tuning parameters for [`PenaltyOptimizer`] and [`analyze_penalty_landscape`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PenaltyConfig {
    /// Weight given to each constraint by `PenaltyOptimizer::initialize_weights`; also
    /// the weight at which `calculate_sensitivity` is evaluated.
    pub initial_weight: f64,
    /// Lower clamp when a weight is decreased (and lower bound for the optimizer-backed
    /// update); first weight of the landscape sweep.
    pub min_weight: f64,
    /// Upper clamp when a weight is increased (and upper bound for the optimizer-backed
    /// update); last weight of the landscape sweep.
    pub max_weight: f64,
    /// Multiplier applied to a weight whose |violation| exceeds `violation_tolerance`;
    /// a weight whose |violation| is below a tenth of the tolerance is divided by the
    /// square root of this factor instead.
    pub adjustment_factor: f64,
    /// Threshold on |mean violation| used for convergence and for the satisfaction rate.
    pub violation_tolerance: f64,
    /// Maximum number of weight-update rounds in `PenaltyOptimizer::optimize_penalties`.
    pub max_iterations: usize,
    /// With the `scirs` feature, selects the optimizer-backed weight update when an
    /// optimizer has been created.
    pub adaptive_scaling: bool,
    /// Shape of the penalty applied to a raw violation.
    pub penalty_type: PenaltyType,
}
36
/// Shape applied to a raw violation `v` before it is multiplied by a weight.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum PenaltyType {
    /// `v²`.
    Quadratic,
    /// `|v|`.
    Linear,
    /// `-ln(v)` for `v > 0`. A non-positive `v` maps to a very large value:
    /// `f64::INFINITY` in `PenaltyOptimizer`, `1e10` (times the weight) in
    /// `calculate_penalty`.
    LogBarrier,
    /// `eᵛ − 1` (computed with `exp_m1`).
    Exponential,
    /// `v² + |v|` (computed with `mul_add`).
    AugmentedLagrangian,
}
51
52impl Default for PenaltyConfig {
53 fn default() -> Self {
54 Self {
55 initial_weight: 1.0,
56 min_weight: 0.001,
57 max_weight: 1000.0,
58 adjustment_factor: 2.0,
59 violation_tolerance: 1e-6,
60 max_iterations: 100,
61 adaptive_scaling: true,
62 penalty_type: PenaltyType::Quadratic,
63 }
64 }
65}
66
/// Adapts per-constraint penalty weights from the violations observed in samples.
pub struct PenaltyOptimizer {
    /// Tuning parameters.
    config: PenaltyConfig,
    /// Current penalty weight per constraint name.
    constraint_weights: HashMap<String, f64>,
    /// One record per constraint per non-converged round of `optimize_penalties`;
    /// never cleared, so it accumulates across calls.
    violation_history: Vec<ConstraintViolation>,
    /// Optimizer for the adaptive weight update; created by `initialize_weights`.
    #[cfg(feature = "scirs")]
    optimizer: Option<Box<dyn Optimizer>>,
}
75
/// One entry of the optimizer's violation history.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConstraintViolation {
    /// Constraint this entry refers to.
    pub constraint_name: String,
    /// Mean penalty-transformed violation over the samples in that round.
    pub violation_amount: f64,
    /// Constraint weight after that round's weight update.
    pub penalty_weight: f64,
    /// Zero-based round index within one `optimize_penalties` call.
    pub iteration: usize,
}
84
/// Outcome of `PenaltyOptimizer::optimize_penalties`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PenaltyOptimizationResult {
    /// Constraint weights when the loop stopped.
    pub optimal_weights: HashMap<String, f64>,
    /// Mean violation per constraint, re-evaluated after the loop.
    pub final_violations: HashMap<String, f64>,
    /// Whether every |mean violation| fell below `violation_tolerance`.
    pub converged: bool,
    /// Number of weight-update rounds performed.
    pub iterations: usize,
    /// Mean over samples of the energy plus the weighted constraint violations.
    pub objective_value: f64,
    /// Fraction of constraints whose final |violation| is below the tolerance.
    pub constraint_satisfaction: f64,
}
95
96impl PenaltyOptimizer {
97 pub fn new(config: PenaltyConfig) -> Self {
99 Self {
100 config,
101 constraint_weights: HashMap::new(),
102 violation_history: Vec::new(),
103 #[cfg(feature = "scirs")]
104 optimizer: None,
105 }
106 }
107
108 pub fn initialize_weights(&mut self, constraints: &[String]) {
110 for constraint in constraints {
111 self.constraint_weights
112 .insert(constraint.clone(), self.config.initial_weight);
113 }
114
115 #[cfg(feature = "scirs")]
116 {
117 use crate::scirs_stub::scirs2_optimization::gradient::LBFGS;
119 self.optimizer = Some(Box::new(LBFGS::new(constraints.len())));
120 }
121 }
122
123 pub fn optimize_penalties(
125 &mut self,
126 model: &CompiledModel,
127 sample_results: &[(Vec<bool>, f64)],
128 ) -> Result<PenaltyOptimizationResult, Box<dyn std::error::Error>> {
129 let mut iteration = 0;
130 let mut converged = false;
131
132 while iteration < self.config.max_iterations && !converged {
133 let violations = self.evaluate_violations(model, sample_results)?;
135
136 let max_violation = violations.values().map(|v| v.abs()).fold(0.0, f64::max);
138
139 if max_violation < self.config.violation_tolerance {
140 converged = true;
141 break;
142 }
143
144 self.update_weights(&violations, iteration)?;
146
147 for (name, &violation) in &violations {
149 self.violation_history.push(ConstraintViolation {
150 constraint_name: name.clone(),
151 violation_amount: violation,
152 penalty_weight: self.constraint_weights[name],
153 iteration,
154 });
155 }
156
157 iteration += 1;
158 }
159
160 let final_violations = self.evaluate_violations(model, sample_results)?;
162 let objective_value = self.calculate_objective(model, sample_results)?;
163 let constraint_satisfaction = self.calculate_satisfaction_rate(&final_violations);
164
165 Ok(PenaltyOptimizationResult {
166 optimal_weights: self.constraint_weights.clone(),
167 final_violations,
168 converged,
169 iterations: iteration,
170 objective_value,
171 constraint_satisfaction,
172 })
173 }
174
175 fn evaluate_violations(
177 &self,
178 model: &CompiledModel,
179 sample_results: &[(Vec<bool>, f64)],
180 ) -> Result<HashMap<String, f64>, Box<dyn std::error::Error>> {
181 let mut violations = HashMap::new();
182
183 for (constraint_name, constraint_expr) in model.get_constraints() {
185 let mut total_violation = 0.0;
186 let mut count = 0;
187
188 for (assignment, _energy) in sample_results {
190 let violation = self.evaluate_constraint_violation(
191 constraint_expr,
192 assignment,
193 model.get_variable_map(),
194 )?;
195
196 total_violation += violation;
197 count += 1;
198 }
199
200 violations.insert(
202 constraint_name.clone(),
203 if count > 0 {
204 total_violation / count as f64
205 } else {
206 0.0
207 },
208 );
209 }
210
211 Ok(violations)
212 }
213
214 fn evaluate_constraint_violation(
216 &self,
217 _constraint: &ConstraintExpr,
218 _assignment: &[bool],
219 _var_map: &HashMap<String, usize>,
220 ) -> Result<f64, Box<dyn std::error::Error>> {
221 let value: f64 = 0.0; Ok(match self.config.penalty_type {
226 PenaltyType::Quadratic => value.powi(2),
227 PenaltyType::Linear => value.abs(),
228 PenaltyType::LogBarrier => {
229 if value > 0.0 {
230 -value.ln()
231 } else {
232 f64::INFINITY
233 }
234 }
235 PenaltyType::Exponential => value.exp_m1(),
236 PenaltyType::AugmentedLagrangian => {
237 value.mul_add(value, value.abs())
239 }
240 })
241 }
242
243 fn update_weights(
245 &mut self,
246 violations: &HashMap<String, f64>,
247 iteration: usize,
248 ) -> Result<(), Box<dyn std::error::Error>> {
249 #[cfg(feature = "scirs")]
250 {
251 if self.config.adaptive_scaling && self.optimizer.is_some() {
252 self.update_weights_optimized(violations, iteration)?;
254 return Ok(());
255 }
256 }
257
258 for (constraint_name, &violation) in violations {
260 if let Some(weight) = self.constraint_weights.get_mut(constraint_name) {
261 if violation.abs() > self.config.violation_tolerance {
262 *weight = (*weight * self.config.adjustment_factor).min(self.config.max_weight);
264 } else if violation.abs() < self.config.violation_tolerance * 0.1 {
265 *weight = (*weight / self.config.adjustment_factor.sqrt())
267 .max(self.config.min_weight);
268 }
269 }
270 }
271
272 Ok(())
273 }
274
275 #[cfg(feature = "scirs")]
276 fn update_weights_optimized(
278 &mut self,
279 violations: &HashMap<String, f64>,
280 iteration: usize,
281 ) -> Result<(), Box<dyn std::error::Error>> {
282 use crate::scirs_stub::scirs2_optimization::{Bounds, ObjectiveFunction};
283
284 let constraint_names: Vec<_> = violations.keys().cloned().collect();
286 let current_weights: Array1<f64> = constraint_names
287 .iter()
288 .map(|name| self.constraint_weights[name])
289 .collect();
290
291 let violations_vec: Array1<f64> = constraint_names
293 .iter()
294 .map(|name| violations[name].abs())
295 .collect();
296
297 let mut objective = WeightOptimizationObjective {
298 violations: violations_vec,
299 penalty_type: self.config.penalty_type,
300 regularization: 0.01, };
302
303 let lower_bounds = Array1::from_elem(constraint_names.len(), self.config.min_weight);
305 let upper_bounds = Array1::from_elem(constraint_names.len(), self.config.max_weight);
306 let bounds = Bounds::new(lower_bounds, upper_bounds);
307
308 if let Some(ref mut optimizer) = self.optimizer {
310 let mut result =
311 optimizer.minimize(&objective, ¤t_weights, &bounds, iteration)?;
312
313 for (i, name) in constraint_names.iter().enumerate() {
315 self.constraint_weights.insert(name.clone(), result.x[i]);
316 }
317 }
318
319 Ok(())
320 }
321
322 fn calculate_objective(
324 &self,
325 model: &CompiledModel,
326 sample_results: &[(Vec<bool>, f64)],
327 ) -> Result<f64, Box<dyn std::error::Error>> {
328 let mut total_objective = 0.0;
329
330 for (assignment, energy) in sample_results {
331 let mut penalized_objective = *energy;
333
334 for (constraint_name, constraint_expr) in model.get_constraints() {
336 let violation = self.evaluate_constraint_violation(
337 constraint_expr,
338 assignment,
339 model.get_variable_map(),
340 )?;
341
342 let weight = self
343 .constraint_weights
344 .get(constraint_name)
345 .copied()
346 .unwrap_or(1.0);
347
348 penalized_objective += weight * violation;
349 }
350
351 total_objective += penalized_objective;
352 }
353
354 Ok(total_objective / sample_results.len() as f64)
355 }
356
357 fn calculate_satisfaction_rate(&self, violations: &HashMap<String, f64>) -> f64 {
359 let satisfied = violations
360 .values()
361 .filter(|&&v| v.abs() < self.config.violation_tolerance)
362 .count();
363
364 if violations.is_empty() {
365 1.0
366 } else {
367 satisfied as f64 / violations.len() as f64
368 }
369 }
370
371 pub fn get_weight(&self, constraint_name: &str) -> Option<f64> {
373 self.constraint_weights.get(constraint_name).copied()
374 }
375
376 pub fn get_violation_history(&self) -> &[ConstraintViolation] {
378 &self.violation_history
379 }
380
381 pub fn export_config(&self) -> PenaltyExport {
383 PenaltyExport {
384 config: self.config.clone(),
385 weights: self.constraint_weights.clone(),
386 final_violations: self
387 .violation_history
388 .iter()
389 .filter(|v| {
390 v.iteration
391 == self
392 .violation_history
393 .iter()
394 .map(|h| h.iteration)
395 .max()
396 .unwrap_or(0)
397 })
398 .map(|v| (v.constraint_name.clone(), v.violation_amount))
399 .collect(),
400 }
401 }
402}
403
/// Snapshot of a [`PenaltyOptimizer`]'s configuration, weights and latest recorded
/// violations, produced by `PenaltyOptimizer::export_config`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PenaltyExport {
    /// Configuration the optimizer was created with.
    pub config: PenaltyConfig,
    /// Current weight per constraint.
    pub weights: HashMap<String, f64>,
    /// Violation per constraint at the highest iteration index in the history.
    pub final_violations: HashMap<String, f64>,
}
411
/// Objective minimized by the optimizer-backed weight update (`scirs` feature only).
#[cfg(feature = "scirs")]
struct WeightOptimizationObjective {
    /// |violation| per constraint, in the same order as the weight vector.
    violations: Array1<f64>,
    /// NOTE(review): not read by the `ObjectiveFunction` impl in this file.
    penalty_type: PenaltyType,
    /// Coefficient of the `‖w‖²` regularization term.
    regularization: f64,
}
419
#[cfg(feature = "scirs")]
impl crate::scirs_stub::scirs2_optimization::ObjectiveFunction for WeightOptimizationObjective {
    /// `Σ wᵢ·vᵢ + λ·‖w‖²`, where `v` is `self.violations` and `λ` is
    /// `self.regularization`.
    fn evaluate(&self, weights: &Array1<f64>) -> f64 {
        let ridge = weights.dot(weights) * self.regularization;
        (weights * &self.violations).sum() + ridge
    }

    /// Analytic gradient of `evaluate`: `vᵢ + 2λ·wᵢ`.
    fn gradient(&self, weights: &Array1<f64>) -> Array1<f64> {
        let ridge_grad = weights * (2.0 * self.regularization);
        &self.violations + &ridge_grad
    }
}
438
/// Compiled optimization model: named constraints plus a variable-name → index map.
#[derive(Debug, Clone)]
pub struct CompiledModel {
    /// Constraint expressions keyed by constraint name.
    constraints: HashMap<String, ConstraintExpr>,
    /// Variable name → index; `to_qubo` sizes its matrix by this map's length.
    /// Presumably the index is the variable's position in an assignment — TODO confirm.
    variable_map: HashMap<String, usize>,
}
445
446impl Default for CompiledModel {
447 fn default() -> Self {
448 Self::new()
449 }
450}
451
452impl CompiledModel {
453 pub fn new() -> Self {
454 Self {
455 constraints: HashMap::new(),
456 variable_map: HashMap::new(),
457 }
458 }
459
460 pub const fn get_constraints(&self) -> &HashMap<String, ConstraintExpr> {
461 &self.constraints
462 }
463
464 pub const fn get_variable_map(&self) -> &HashMap<String, usize> {
465 &self.variable_map
466 }
467
468 pub fn to_qubo(&self) -> (Array2<f64>, HashMap<String, usize>) {
469 let size = self.variable_map.len();
470 (Array2::zeros((size, size)), self.variable_map.clone())
471 }
472}
473
/// A constraint in textual form.
#[derive(Debug, Clone)]
pub struct ConstraintExpr {
    /// Source text of the constraint; its syntax is not defined in this file.
    pub expression: String,
}
479
/// Evaluates a term against a boolean assignment.
///
/// NOTE(review): no implementation or caller is visible in this file; as a private
/// trait it will trigger a dead-code warning unless used elsewhere in the module.
trait TermEvaluator {
    /// Returns the term's value for `assignment`; `var_map` presumably maps variable
    /// names to positions in `assignment` — TODO confirm.
    fn evaluate_with_assignment(
        &self,
        assignment: &[bool],
        var_map: &HashMap<String, usize>,
    ) -> Result<f64, Box<dyn std::error::Error>>;
}
488
489pub fn analyze_penalty_landscape(config: &PenaltyConfig, violations: &[f64]) -> PenaltyAnalysis {
491 let weights = Array1::linspace(config.min_weight, config.max_weight, 100);
492 let mut penalties = Vec::new();
493
494 for &weight in &weights {
495 let penalty_values: Vec<f64> = violations
496 .iter()
497 .map(|&v| calculate_penalty(v, weight, config.penalty_type))
498 .collect();
499
500 penalties.push(PenaltyPoint {
501 weight,
502 avg_penalty: penalty_values.iter().sum::<f64>() / penalty_values.len() as f64,
503 max_penalty: penalty_values.iter().fold(0.0, |a, &b| a.max(b)),
504 min_penalty: penalty_values.iter().fold(f64::INFINITY, |a, &b| a.min(b)),
505 });
506 }
507
508 PenaltyAnalysis {
509 penalty_points: penalties,
510 optimal_weight: find_optimal_weight(&weights, violations, config),
511 sensitivity: calculate_sensitivity(violations, config),
512 }
513}
514
515fn calculate_penalty(violation: f64, weight: f64, penalty_type: PenaltyType) -> f64 {
517 weight
518 * match penalty_type {
519 PenaltyType::Quadratic => violation.powi(2),
520 PenaltyType::Linear => violation.abs(),
521 PenaltyType::LogBarrier => {
522 if violation > 0.0 {
523 -violation.ln()
524 } else {
525 1e10 }
527 }
528 PenaltyType::Exponential => violation.exp_m1(),
529 PenaltyType::AugmentedLagrangian => violation.mul_add(violation, violation.abs()),
530 }
531}
532
533fn find_optimal_weight(weights: &Array1<f64>, violations: &[f64], config: &PenaltyConfig) -> f64 {
535 let target_penalty = violations.len() as f64 * config.violation_tolerance;
538
539 let mut best_weight = config.initial_weight;
540 let mut best_diff = f64::INFINITY;
541
542 for &weight in weights {
543 let total_penalty: f64 = violations
544 .iter()
545 .map(|&v| calculate_penalty(v, weight, config.penalty_type))
546 .sum();
547
548 let diff = (total_penalty - target_penalty).abs();
549 if diff < best_diff {
550 best_diff = diff;
551 best_weight = weight;
552 }
553 }
554
555 best_weight
556}
557
558fn calculate_sensitivity(violations: &[f64], config: &PenaltyConfig) -> f64 {
560 if violations.is_empty() {
561 return 0.0;
562 }
563
564 let weight = config.initial_weight;
566 let penalties: Vec<f64> = violations
567 .iter()
568 .map(|&v| calculate_penalty(v, weight, config.penalty_type))
569 .collect();
570
571 let delta = 0.01 * weight;
572 let penalties_delta: Vec<f64> = violations
573 .iter()
574 .map(|&v| calculate_penalty(v, weight + delta, config.penalty_type))
575 .collect();
576
577 let derivatives: Vec<f64> = penalties
578 .iter()
579 .zip(penalties_delta.iter())
580 .map(|(&p1, &p2)| (p2 - p1) / delta)
581 .collect();
582
583 derivatives.iter().sum::<f64>() / derivatives.len() as f64
585}
586
/// Result of [`analyze_penalty_landscape`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PenaltyAnalysis {
    /// One point per sampled weight, in sampling order from `min_weight` to `max_weight`.
    pub penalty_points: Vec<PenaltyPoint>,
    /// Sampled weight whose total penalty is closest to
    /// `violations.len() × violation_tolerance` (or `initial_weight` if none qualifies).
    pub optimal_weight: f64,
    /// Mean finite-difference derivative of the penalty w.r.t. the weight, taken at
    /// `initial_weight`.
    pub sensitivity: f64,
}
594
/// Penalty statistics over all violations at one sampled weight.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PenaltyPoint {
    /// The sampled weight.
    pub weight: f64,
    /// Mean penalty across the violations.
    pub avg_penalty: f64,
    /// Largest penalty across the violations.
    pub max_penalty: f64,
    /// Smallest penalty across the violations.
    pub min_penalty: f64,
}