quantrs2_tytan/optimization/penalty.rs

1//! Penalty function optimization for QUBO problems
2//!
3//! This module provides advanced penalty weight optimization using SciRS2
4//! for automatic tuning and constraint satisfaction analysis.
5
6use scirs2_core::ndarray::{Array1, Array2};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
10#[cfg(feature = "scirs")]
11use crate::scirs_stub::{
12    scirs2_linalg::norm::Norm,
13    scirs2_optimization::{OptimizationProblem, Optimizer},
14};
15
16/// Penalty function configuration
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct PenaltyConfig {
19    /// Initial penalty weight
20    pub initial_weight: f64,
21    /// Minimum penalty weight
22    pub min_weight: f64,
23    /// Maximum penalty weight
24    pub max_weight: f64,
25    /// Weight adjustment factor
26    pub adjustment_factor: f64,
27    /// Target constraint violation tolerance
28    pub violation_tolerance: f64,
29    /// Maximum optimization iterations
30    pub max_iterations: usize,
31    /// Use adaptive penalty scaling
32    pub adaptive_scaling: bool,
33    /// Penalty function type
34    pub penalty_type: PenaltyType,
35}
36
37/// Types of penalty functions
38#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
39pub enum PenaltyType {
40    /// Quadratic penalty: weight * violation^2
41    Quadratic,
42    /// Linear penalty: weight * |violation|
43    Linear,
44    /// Logarithmic barrier: -weight * log(slack)
45    LogBarrier,
46    /// Exponential penalty: weight * exp(violation) - 1
47    Exponential,
48    /// Augmented Lagrangian method
49    AugmentedLagrangian,
50}
51
52impl Default for PenaltyConfig {
53    fn default() -> Self {
54        Self {
55            initial_weight: 1.0,
56            min_weight: 0.001,
57            max_weight: 1000.0,
58            adjustment_factor: 2.0,
59            violation_tolerance: 1e-6,
60            max_iterations: 100,
61            adaptive_scaling: true,
62            penalty_type: PenaltyType::Quadratic,
63        }
64    }
65}
66
67/// Penalty function optimizer
68pub struct PenaltyOptimizer {
69    config: PenaltyConfig,
70    constraint_weights: HashMap<String, f64>,
71    violation_history: Vec<ConstraintViolation>,
72    #[cfg(feature = "scirs")]
73    optimizer: Option<Box<dyn Optimizer>>,
74}
75
76/// Constraint violation information
77#[derive(Debug, Clone, Serialize, Deserialize)]
78pub struct ConstraintViolation {
79    pub constraint_name: String,
80    pub violation_amount: f64,
81    pub penalty_weight: f64,
82    pub iteration: usize,
83}
84
85/// Penalty optimization result
86#[derive(Debug, Clone, Serialize, Deserialize)]
87pub struct PenaltyOptimizationResult {
88    pub optimal_weights: HashMap<String, f64>,
89    pub final_violations: HashMap<String, f64>,
90    pub converged: bool,
91    pub iterations: usize,
92    pub objective_value: f64,
93    pub constraint_satisfaction: f64,
94}
95
96impl PenaltyOptimizer {
97    /// Create new penalty optimizer
98    pub fn new(config: PenaltyConfig) -> Self {
99        Self {
100            config,
101            constraint_weights: HashMap::new(),
102            violation_history: Vec::new(),
103            #[cfg(feature = "scirs")]
104            optimizer: None,
105        }
106    }
107
108    /// Initialize constraint weights
109    pub fn initialize_weights(&mut self, constraints: &[String]) {
110        for constraint in constraints {
111            self.constraint_weights
112                .insert(constraint.clone(), self.config.initial_weight);
113        }
114
115        #[cfg(feature = "scirs")]
116        {
117            // Initialize SciRS2 optimizer
118            use crate::scirs_stub::scirs2_optimization::gradient::LBFGS;
119            self.optimizer = Some(Box::new(LBFGS::new(constraints.len())));
120        }
121    }
122
123    /// Optimize penalty weights for a compiled model
124    pub fn optimize_penalties(
125        &mut self,
126        model: &CompiledModel,
127        sample_results: &[(Vec<bool>, f64)],
128    ) -> Result<PenaltyOptimizationResult, Box<dyn std::error::Error>> {
129        let mut iteration = 0;
130        let mut converged = false;
131
132        while iteration < self.config.max_iterations && !converged {
133            // Evaluate constraint violations
134            let violations = self.evaluate_violations(model, sample_results)?;
135
136            // Check convergence
137            let max_violation = violations.values().map(|v| v.abs()).fold(0.0, f64::max);
138
139            if max_violation < self.config.violation_tolerance {
140                converged = true;
141                break;
142            }
143
144            // Update penalty weights
145            self.update_weights(&violations, iteration)?;
146
147            // Record history
148            for (name, &violation) in &violations {
149                self.violation_history.push(ConstraintViolation {
150                    constraint_name: name.clone(),
151                    violation_amount: violation,
152                    penalty_weight: self.constraint_weights[name],
153                    iteration,
154                });
155            }
156
157            iteration += 1;
158        }
159
160        // Calculate final metrics
161        let final_violations = self.evaluate_violations(model, sample_results)?;
162        let objective_value = self.calculate_objective(model, sample_results)?;
163        let constraint_satisfaction = self.calculate_satisfaction_rate(&final_violations);
164
165        Ok(PenaltyOptimizationResult {
166            optimal_weights: self.constraint_weights.clone(),
167            final_violations,
168            converged,
169            iterations: iteration,
170            objective_value,
171            constraint_satisfaction,
172        })
173    }
174
175    /// Evaluate constraint violations
176    fn evaluate_violations(
177        &self,
178        model: &CompiledModel,
179        sample_results: &[(Vec<bool>, f64)],
180    ) -> Result<HashMap<String, f64>, Box<dyn std::error::Error>> {
181        let mut violations = HashMap::new();
182
183        // For each constraint in the model
184        for (constraint_name, constraint_expr) in model.get_constraints() {
185            let mut total_violation = 0.0;
186            let mut count = 0;
187
188            // Evaluate constraint for each sample
189            for (assignment, _energy) in sample_results {
190                let violation = self.evaluate_constraint_violation(
191                    constraint_expr,
192                    assignment,
193                    model.get_variable_map(),
194                )?;
195
196                total_violation += violation;
197                count += 1;
198            }
199
200            // Average violation
201            violations.insert(
202                constraint_name.clone(),
203                if count > 0 {
204                    total_violation / count as f64
205                } else {
206                    0.0
207                },
208            );
209        }
210
211        Ok(violations)
212    }
213
214    /// Evaluate single constraint violation
215    fn evaluate_constraint_violation(
216        &self,
217        _constraint: &ConstraintExpr,
218        _assignment: &[bool],
219        _var_map: &HashMap<String, usize>,
220    ) -> Result<f64, Box<dyn std::error::Error>> {
221        // Placeholder evaluation - in real implementation would parse and evaluate expression
222        let value: f64 = 0.0; // Placeholder
223
224        // Calculate violation based on constraint type
225        Ok(match self.config.penalty_type {
226            PenaltyType::Quadratic => value.powi(2),
227            PenaltyType::Linear => value.abs(),
228            PenaltyType::LogBarrier => {
229                if value > 0.0 {
230                    -value.ln()
231                } else {
232                    f64::INFINITY
233                }
234            }
235            PenaltyType::Exponential => value.exp_m1(),
236            PenaltyType::AugmentedLagrangian => {
237                // Simplified augmented Lagrangian
238                value.mul_add(value, value.abs())
239            }
240        })
241    }
242
243    /// Update penalty weights based on violations
244    fn update_weights(
245        &mut self,
246        violations: &HashMap<String, f64>,
247        iteration: usize,
248    ) -> Result<(), Box<dyn std::error::Error>> {
249        #[cfg(feature = "scirs")]
250        {
251            if self.config.adaptive_scaling && self.optimizer.is_some() {
252                // Use SciRS2 optimizer for weight updates
253                self.update_weights_optimized(violations, iteration)?;
254                return Ok(());
255            }
256        }
257
258        // Standard weight update
259        for (constraint_name, &violation) in violations {
260            if let Some(weight) = self.constraint_weights.get_mut(constraint_name) {
261                if violation.abs() > self.config.violation_tolerance {
262                    // Increase penalty weight
263                    *weight = (*weight * self.config.adjustment_factor).min(self.config.max_weight);
264                } else if violation.abs() < self.config.violation_tolerance * 0.1 {
265                    // Decrease penalty weight if over-penalized
266                    *weight = (*weight / self.config.adjustment_factor.sqrt())
267                        .max(self.config.min_weight);
268                }
269            }
270        }
271
272        Ok(())
273    }
274
275    #[cfg(feature = "scirs")]
276    /// Update weights using SciRS2 optimization
277    fn update_weights_optimized(
278        &mut self,
279        violations: &HashMap<String, f64>,
280        iteration: usize,
281    ) -> Result<(), Box<dyn std::error::Error>> {
282        use crate::scirs_stub::scirs2_optimization::{Bounds, ObjectiveFunction};
283
284        // Define optimization problem
285        let constraint_names: Vec<_> = violations.keys().cloned().collect();
286        let current_weights: Array1<f64> = constraint_names
287            .iter()
288            .map(|name| self.constraint_weights[name])
289            .collect();
290
291        // Objective: minimize total weighted violations
292        let violations_vec: Array1<f64> = constraint_names
293            .iter()
294            .map(|name| violations[name].abs())
295            .collect();
296
297        let mut objective = WeightOptimizationObjective {
298            violations: violations_vec,
299            penalty_type: self.config.penalty_type,
300            regularization: 0.01, // L2 regularization on weights
301        };
302
303        // Set bounds
304        let lower_bounds = Array1::from_elem(constraint_names.len(), self.config.min_weight);
305        let upper_bounds = Array1::from_elem(constraint_names.len(), self.config.max_weight);
306        let bounds = Bounds::new(lower_bounds, upper_bounds);
307
308        // Optimize
309        if let Some(ref mut optimizer) = self.optimizer {
310            let mut result =
311                optimizer.minimize(&objective, &current_weights, &bounds, iteration)?;
312
313            // Update weights
314            for (i, name) in constraint_names.iter().enumerate() {
315                self.constraint_weights.insert(name.clone(), result.x[i]);
316            }
317        }
318
319        Ok(())
320    }
321
322    /// Calculate objective value
323    fn calculate_objective(
324        &self,
325        model: &CompiledModel,
326        sample_results: &[(Vec<bool>, f64)],
327    ) -> Result<f64, Box<dyn std::error::Error>> {
328        let mut total_objective = 0.0;
329
330        for (assignment, energy) in sample_results {
331            // Original objective
332            let mut penalized_objective = *energy;
333
334            // Add penalty terms
335            for (constraint_name, constraint_expr) in model.get_constraints() {
336                let violation = self.evaluate_constraint_violation(
337                    constraint_expr,
338                    assignment,
339                    model.get_variable_map(),
340                )?;
341
342                let weight = self
343                    .constraint_weights
344                    .get(constraint_name)
345                    .copied()
346                    .unwrap_or(1.0);
347
348                penalized_objective += weight * violation;
349            }
350
351            total_objective += penalized_objective;
352        }
353
354        Ok(total_objective / sample_results.len() as f64)
355    }
356
357    /// Calculate constraint satisfaction rate
358    fn calculate_satisfaction_rate(&self, violations: &HashMap<String, f64>) -> f64 {
359        let satisfied = violations
360            .values()
361            .filter(|&&v| v.abs() < self.config.violation_tolerance)
362            .count();
363
364        if violations.is_empty() {
365            1.0
366        } else {
367            satisfied as f64 / violations.len() as f64
368        }
369    }
370
371    /// Get penalty weight for a constraint
372    pub fn get_weight(&self, constraint_name: &str) -> Option<f64> {
373        self.constraint_weights.get(constraint_name).copied()
374    }
375
376    /// Get violation history
377    pub fn get_violation_history(&self) -> &[ConstraintViolation] {
378        &self.violation_history
379    }
380
381    /// Export penalty configuration
382    pub fn export_config(&self) -> PenaltyExport {
383        PenaltyExport {
384            config: self.config.clone(),
385            weights: self.constraint_weights.clone(),
386            final_violations: self
387                .violation_history
388                .iter()
389                .filter(|v| {
390                    v.iteration
391                        == self
392                            .violation_history
393                            .iter()
394                            .map(|h| h.iteration)
395                            .max()
396                            .unwrap_or(0)
397                })
398                .map(|v| (v.constraint_name.clone(), v.violation_amount))
399                .collect(),
400        }
401    }
402}
403
404/// Exported penalty configuration
405#[derive(Debug, Clone, Serialize, Deserialize)]
406pub struct PenaltyExport {
407    pub config: PenaltyConfig,
408    pub weights: HashMap<String, f64>,
409    pub final_violations: HashMap<String, f64>,
410}
411
/// Regularized weighted-violation objective minimized by the SciRS2 weight update.
#[cfg(feature = "scirs")]
struct WeightOptimizationObjective {
    /// Absolute violation per constraint, in the same order as the weight vector.
    violations: Array1<f64>,
    /// Penalty shape copied from the configuration; not read by `evaluate`/`gradient`.
    penalty_type: PenaltyType,
    /// Coefficient of the `‖w‖²` regularization term.
    regularization: f64,
}
419
#[cfg(feature = "scirs")]
impl crate::scirs_stub::scirs2_optimization::ObjectiveFunction for WeightOptimizationObjective {
    /// `f(w) = Σ wᵢ·vᵢ + λ·‖w‖²`, with `λ = self.regularization`.
    fn evaluate(&self, weights: &Array1<f64>) -> f64 {
        let l2_term = self.regularization * weights.dot(weights);
        let violation_term = (weights * &self.violations).sum();
        violation_term + l2_term
    }

    /// `∇f(w) = v + 2λ·w`.
    fn gradient(&self, weights: &Array1<f64>) -> Array1<f64> {
        let l2_gradient = 2.0 * self.regularization * weights;
        &self.violations + l2_gradient
    }
}
438
439/// Compiled model placeholder
440#[derive(Debug, Clone)]
441pub struct CompiledModel {
442    constraints: HashMap<String, ConstraintExpr>,
443    variable_map: HashMap<String, usize>,
444}
445
446impl Default for CompiledModel {
447    fn default() -> Self {
448        Self::new()
449    }
450}
451
452impl CompiledModel {
453    pub fn new() -> Self {
454        Self {
455            constraints: HashMap::new(),
456            variable_map: HashMap::new(),
457        }
458    }
459
460    pub const fn get_constraints(&self) -> &HashMap<String, ConstraintExpr> {
461        &self.constraints
462    }
463
464    pub const fn get_variable_map(&self) -> &HashMap<String, usize> {
465        &self.variable_map
466    }
467
468    pub fn to_qubo(&self) -> (Array2<f64>, HashMap<String, usize>) {
469        let size = self.variable_map.len();
470        (Array2::zeros((size, size)), self.variable_map.clone())
471    }
472}
473
/// Placeholder for a constraint; stores the expression as unparsed source text.
#[derive(Debug, Clone)]
pub struct ConstraintExpr {
    /// Raw expression text (not parsed or evaluated anywhere in this module).
    pub expression: String,
}
479
/// Evaluates a term against a concrete boolean assignment.
///
/// NOTE(review): no implementor is visible in this module; it may be unused.
trait TermEvaluator {
    /// Numeric value of the term under `assignment`, with `var_map` resolving variable
    /// names to positions in `assignment` (presumably; confirm with implementors).
    fn evaluate_with_assignment(
        &self,
        assignment: &[bool],
        var_map: &HashMap<String, usize>,
    ) -> Result<f64, Box<dyn std::error::Error>>;
}
488
489/// Analyze penalty function behavior
490pub fn analyze_penalty_landscape(config: &PenaltyConfig, violations: &[f64]) -> PenaltyAnalysis {
491    let weights = Array1::linspace(config.min_weight, config.max_weight, 100);
492    let mut penalties = Vec::new();
493
494    for &weight in &weights {
495        let penalty_values: Vec<f64> = violations
496            .iter()
497            .map(|&v| calculate_penalty(v, weight, config.penalty_type))
498            .collect();
499
500        penalties.push(PenaltyPoint {
501            weight,
502            avg_penalty: penalty_values.iter().sum::<f64>() / penalty_values.len() as f64,
503            max_penalty: penalty_values.iter().fold(0.0, |a, &b| a.max(b)),
504            min_penalty: penalty_values.iter().fold(f64::INFINITY, |a, &b| a.min(b)),
505        });
506    }
507
508    PenaltyAnalysis {
509        penalty_points: penalties,
510        optimal_weight: find_optimal_weight(&weights, violations, config),
511        sensitivity: calculate_sensitivity(violations, config),
512    }
513}
514
515/// Calculate penalty value
516fn calculate_penalty(violation: f64, weight: f64, penalty_type: PenaltyType) -> f64 {
517    weight
518        * match penalty_type {
519            PenaltyType::Quadratic => violation.powi(2),
520            PenaltyType::Linear => violation.abs(),
521            PenaltyType::LogBarrier => {
522                if violation > 0.0 {
523                    -violation.ln()
524                } else {
525                    1e10 // Large penalty for infeasible region
526                }
527            }
528            PenaltyType::Exponential => violation.exp_m1(),
529            PenaltyType::AugmentedLagrangian => violation.mul_add(violation, violation.abs()),
530        }
531}
532
533/// Find optimal penalty weight
534fn find_optimal_weight(weights: &Array1<f64>, violations: &[f64], config: &PenaltyConfig) -> f64 {
535    // Simple heuristic: find weight that balances constraint satisfaction
536    // with objective minimization
537    let target_penalty = violations.len() as f64 * config.violation_tolerance;
538
539    let mut best_weight = config.initial_weight;
540    let mut best_diff = f64::INFINITY;
541
542    for &weight in weights {
543        let total_penalty: f64 = violations
544            .iter()
545            .map(|&v| calculate_penalty(v, weight, config.penalty_type))
546            .sum();
547
548        let diff = (total_penalty - target_penalty).abs();
549        if diff < best_diff {
550            best_diff = diff;
551            best_weight = weight;
552        }
553    }
554
555    best_weight
556}
557
558/// Calculate penalty sensitivity
559fn calculate_sensitivity(violations: &[f64], config: &PenaltyConfig) -> f64 {
560    if violations.is_empty() {
561        return 0.0;
562    }
563
564    // Calculate derivative of penalty w.r.t. weight at current weight
565    let weight = config.initial_weight;
566    let penalties: Vec<f64> = violations
567        .iter()
568        .map(|&v| calculate_penalty(v, weight, config.penalty_type))
569        .collect();
570
571    let delta = 0.01 * weight;
572    let penalties_delta: Vec<f64> = violations
573        .iter()
574        .map(|&v| calculate_penalty(v, weight + delta, config.penalty_type))
575        .collect();
576
577    let derivatives: Vec<f64> = penalties
578        .iter()
579        .zip(penalties_delta.iter())
580        .map(|(&p1, &p2)| (p2 - p1) / delta)
581        .collect();
582
583    // Return average sensitivity
584    derivatives.iter().sum::<f64>() / derivatives.len() as f64
585}
586
587/// Penalty analysis results
588#[derive(Debug, Clone, Serialize, Deserialize)]
589pub struct PenaltyAnalysis {
590    pub penalty_points: Vec<PenaltyPoint>,
591    pub optimal_weight: f64,
592    pub sensitivity: f64,
593}
594
595/// Penalty evaluation point
596#[derive(Debug, Clone, Serialize, Deserialize)]
597pub struct PenaltyPoint {
598    pub weight: f64,
599    pub avg_penalty: f64,
600    pub max_penalty: f64,
601    pub min_penalty: f64,
602}