1use scirs2_core::random::{ChaCha8Rng, Rng, SeedableRng};
42use serde::{Deserialize, Serialize};
43use std::collections::HashMap;
44use std::time::{Duration, Instant};
45
46use crate::applications::ApplicationResult;
47use crate::ising::IsingModel;
48
49#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
51pub enum ConstraintType {
52 Equality,
54 LessThanOrEqual,
56 GreaterThanOrEqual,
58 Custom,
60}
61
62#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
64pub enum ConstraintPriority {
65 Low,
67 Medium,
69 High,
71 Critical,
73}
74
75#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
77pub enum PenaltyStrategy {
78 Static,
80 Multiplicative,
82 Additive,
84 Adaptive,
86 Exponential,
88}
89
90#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
92pub enum RelaxationStrategy {
93 None,
95 Linear,
97 Exponential,
99 Adaptive,
101}
102
103#[derive(Debug, Clone, Serialize, Deserialize)]
105pub struct Constraint {
106 pub id: String,
108
109 pub constraint_type: ConstraintType,
111
112 pub variables: Vec<usize>,
114
115 pub target_value: f64,
117
118 pub penalty_coefficient: f64,
120
121 pub initial_penalty: f64,
123
124 pub priority: ConstraintPriority,
126
127 pub is_soft: bool,
129
130 pub tolerance: f64,
132
133 pub violation_count: usize,
135
136 pub cumulative_violation: f64,
138}
139
140#[derive(Debug, Clone, Serialize, Deserialize)]
142pub struct AdaptiveConstraintConfig {
143 pub initial_penalty: f64,
145
146 pub min_penalty: f64,
148
149 pub max_penalty: f64,
151
152 pub penalty_strategy: PenaltyStrategy,
154
155 pub penalty_increase_factor: f64,
157
158 pub penalty_decrease_factor: f64,
160
161 pub relaxation_strategy: RelaxationStrategy,
163
164 pub relaxation_rate: f64,
166
167 pub max_adaptation_iterations: usize,
169
170 pub violation_tolerance: f64,
172
173 pub enable_tightening: bool,
175
176 pub enable_prioritization: bool,
178
179 pub history_window: usize,
181}
182
183#[derive(Debug, Clone, Default, Serialize, Deserialize)]
185pub struct ConstraintStatistics {
186 pub total_constraints: usize,
188
189 pub hard_constraints: usize,
191
192 pub soft_constraints: usize,
194
195 pub total_violations: usize,
197
198 pub current_violations: usize,
200
201 pub avg_penalty_coefficient: f64,
203
204 pub max_violation: f64,
206
207 pub avg_violation: f64,
209
210 pub satisfaction_rate: f64,
212
213 pub num_adaptations: usize,
215
216 pub num_relaxations: usize,
218}
219
220#[derive(Debug, Clone, Serialize, Deserialize)]
222pub struct ViolationRecord {
223 pub constraint_id: String,
225
226 pub iteration: usize,
228
229 pub magnitude: f64,
231
232 pub penalty: f64,
234
235 pub resolved: bool,
237}
238
239pub struct AdaptiveConstraintHandler {
243 pub config: AdaptiveConstraintConfig,
245
246 pub constraints: HashMap<String, Constraint>,
248
249 pub violation_history: Vec<ViolationRecord>,
251
252 pub statistics: ConstraintStatistics,
254
255 pub current_iteration: usize,
257
258 rng: ChaCha8Rng,
260}
261
262impl Constraint {
263 pub fn new(
265 id: impl Into<String>,
266 constraint_type: ConstraintType,
267 variables: Vec<usize>,
268 target_value: f64,
269 priority: ConstraintPriority,
270 ) -> Self {
271 let initial_penalty = match priority {
272 ConstraintPriority::Low => 1.0,
273 ConstraintPriority::Medium => 10.0,
274 ConstraintPriority::High => 100.0,
275 ConstraintPriority::Critical => 1000.0,
276 };
277
278 Self {
279 id: id.into(),
280 constraint_type,
281 variables,
282 target_value,
283 penalty_coefficient: initial_penalty,
284 initial_penalty,
285 priority,
286 is_soft: priority != ConstraintPriority::Critical,
287 tolerance: 0.01,
288 violation_count: 0,
289 cumulative_violation: 0.0,
290 }
291 }
292
293 #[must_use]
295 pub fn evaluate(&self, solution: &[i8]) -> f64 {
296 let sum: i8 = self
298 .variables
299 .iter()
300 .filter_map(|&idx| solution.get(idx))
301 .sum();
302
303 let value = f64::from(sum);
304
305 match self.constraint_type {
306 ConstraintType::Equality => (value - self.target_value).abs(),
307 ConstraintType::LessThanOrEqual => (value - self.target_value).max(0.0),
308 ConstraintType::GreaterThanOrEqual => (self.target_value - value).max(0.0),
309 ConstraintType::Custom => {
310 (value - self.target_value).abs()
313 }
314 }
315 }
316
317 #[must_use]
319 pub fn is_violated(&self, solution: &[i8]) -> bool {
320 let violation = self.evaluate(solution);
321
322 if self.is_soft {
323 violation > self.tolerance
324 } else {
325 violation > 1e-10
326 }
327 }
328
329 #[must_use]
331 pub fn penalty_term(&self, solution: &[i8]) -> f64 {
332 let violation = self.evaluate(solution);
333 self.penalty_coefficient * violation * violation
334 }
335}
336
337impl Default for AdaptiveConstraintConfig {
338 fn default() -> Self {
339 Self {
340 initial_penalty: 10.0,
341 min_penalty: 0.1,
342 max_penalty: 10_000.0,
343 penalty_strategy: PenaltyStrategy::Adaptive,
344 penalty_increase_factor: 1.5,
345 penalty_decrease_factor: 0.9,
346 relaxation_strategy: RelaxationStrategy::Adaptive,
347 relaxation_rate: 0.01,
348 max_adaptation_iterations: 1000,
349 violation_tolerance: 0.01,
350 enable_tightening: true,
351 enable_prioritization: true,
352 history_window: 50,
353 }
354 }
355}
356
357impl AdaptiveConstraintHandler {
358 #[must_use]
360 pub fn new(config: AdaptiveConstraintConfig) -> Self {
361 Self {
362 config,
363 constraints: HashMap::new(),
364 violation_history: Vec::new(),
365 statistics: ConstraintStatistics::default(),
366 current_iteration: 0,
367 rng: ChaCha8Rng::from_seed([0u8; 32]),
368 }
369 }
370
371 pub fn add_constraint(&mut self, constraint: Constraint) {
373 let id = constraint.id.clone();
374 self.constraints.insert(id, constraint);
375 self.update_constraint_counts();
376 }
377
378 pub fn remove_constraint(&mut self, id: &str) -> Option<Constraint> {
380 let result = self.constraints.remove(id);
381 self.update_constraint_counts();
382 result
383 }
384
385 #[must_use]
387 pub fn evaluate_all(&self, solution: &[i8]) -> HashMap<String, f64> {
388 self.constraints
389 .iter()
390 .map(|(id, constraint)| (id.clone(), constraint.evaluate(solution)))
391 .collect()
392 }
393
394 #[must_use]
396 pub fn is_feasible(&self, solution: &[i8]) -> bool {
397 self.constraints.values().all(|c| !c.is_violated(solution))
398 }
399
400 #[must_use]
402 pub fn total_penalty(&self, solution: &[i8]) -> f64 {
403 self.constraints
404 .values()
405 .map(|c| c.penalty_term(solution))
406 .sum()
407 }
408
409 pub fn adapt_penalties(&mut self, solution: &[i8]) {
411 self.current_iteration += 1;
412
413 let violations = self.evaluate_all(solution);
414
415 let adaptation_data: Vec<(String, f64, bool)> = violations
417 .iter()
418 .filter_map(|(id, violation)| {
419 self.constraints.get(id).map(|constraint| {
420 let is_violated = *violation > constraint.tolerance;
421 (id.clone(), *violation, is_violated)
422 })
423 })
424 .collect();
425
426 let config = self.config.clone();
428 let current_iteration = self.current_iteration;
429
430 for (id, violation, is_violated) in adaptation_data {
432 if let Some(constraint) = self.constraints.get_mut(&id) {
433 if is_violated {
434 constraint.violation_count += 1;
436 constraint.cumulative_violation += violation;
437
438 self.violation_history.push(ViolationRecord {
439 constraint_id: id.clone(),
440 iteration: self.current_iteration,
441 magnitude: violation,
442 penalty: constraint.penalty_coefficient,
443 resolved: false,
444 });
445
446 let priority_factor = match constraint.priority {
448 ConstraintPriority::Low => 1.0,
449 ConstraintPriority::Medium => 1.5,
450 ConstraintPriority::High => 2.0,
451 ConstraintPriority::Critical => 3.0,
452 };
453
454 Self::apply_penalty_increase(constraint, violation, priority_factor, &config);
455 } else if constraint.violation_count > 0 {
456 Self::apply_penalty_decrease(constraint, current_iteration, &config);
458 }
459 }
460 }
461
462 self.update_statistics(solution);
464 self.statistics.num_adaptations += 1;
465 }
466
467 fn apply_penalty_increase(
469 constraint: &mut Constraint,
470 violation: f64,
471 priority_factor: f64,
472 config: &AdaptiveConstraintConfig,
473 ) {
474 match config.penalty_strategy {
475 PenaltyStrategy::Static => {
476 }
478 PenaltyStrategy::Multiplicative => {
479 constraint.penalty_coefficient *= config.penalty_increase_factor;
480 }
481 PenaltyStrategy::Additive => {
482 let increase = config.initial_penalty * 0.1;
483 constraint.penalty_coefficient += increase;
484 }
485 PenaltyStrategy::Adaptive => {
486 let violation_factor =
488 (violation / constraint.target_value.abs().max(1.0)).min(10.0);
489 let increase_mult = (priority_factor * violation_factor).mul_add(0.1, 1.0);
490 constraint.penalty_coefficient *= increase_mult;
491 }
492 PenaltyStrategy::Exponential => {
493 let exp_factor = (constraint.violation_count as f64 * 0.1).exp();
495 constraint.penalty_coefficient *= exp_factor.min(2.0);
496 }
497 }
498
499 constraint.penalty_coefficient = constraint
501 .penalty_coefficient
502 .max(config.min_penalty)
503 .min(config.max_penalty);
504 }
505
506 fn apply_penalty_decrease(
508 constraint: &mut Constraint,
509 current_iteration: usize,
510 config: &AdaptiveConstraintConfig,
511 ) {
512 match config.penalty_strategy {
513 PenaltyStrategy::Static => {
514 }
516 PenaltyStrategy::Multiplicative | PenaltyStrategy::Exponential => {
517 constraint.penalty_coefficient *= config.penalty_decrease_factor;
518 }
519 PenaltyStrategy::Additive => {
520 let decrease = config.initial_penalty * 0.05;
521 constraint.penalty_coefficient =
522 (constraint.penalty_coefficient - decrease).max(0.0);
523 }
524 PenaltyStrategy::Adaptive => {
525 if current_iteration % 10 == 0 {
527 constraint.penalty_coefficient *= config.penalty_decrease_factor;
528 }
529 }
530 }
531
532 constraint.penalty_coefficient = constraint
534 .penalty_coefficient
535 .max(config.min_penalty)
536 .min(config.max_penalty);
537 }
538
539 pub fn apply_relaxation(&mut self) {
541 if self.config.relaxation_strategy == RelaxationStrategy::None {
542 return;
543 }
544
545 for constraint in self.constraints.values_mut() {
546 if !constraint.is_soft {
547 continue; }
549
550 let relaxation_amount = match self.config.relaxation_strategy {
551 RelaxationStrategy::None => 0.0,
552 RelaxationStrategy::Linear => self.config.relaxation_rate,
553 RelaxationStrategy::Exponential => {
554 constraint.tolerance * self.config.relaxation_rate
555 }
556 RelaxationStrategy::Adaptive => {
557 let violation_rate =
559 constraint.violation_count as f64 / self.current_iteration.max(1) as f64;
560 self.config.relaxation_rate * (1.0 + violation_rate)
561 }
562 };
563
564 constraint.tolerance += relaxation_amount;
565
566 constraint.tolerance = constraint.tolerance.min(1.0);
568 }
569
570 self.statistics.num_relaxations += 1;
571 }
572
573 pub fn apply_to_model(&self, model: &mut IsingModel) -> ApplicationResult<()> {
575 for constraint in self.constraints.values() {
576 for i in 0..constraint.variables.len() {
578 for j in i..constraint.variables.len() {
579 let var_i = constraint.variables[i];
580 let var_j = constraint.variables[j];
581
582 if i == j {
583 let current_bias = model.get_bias(var_i).unwrap_or(0.0);
585 let penalty_bias = constraint.penalty_coefficient * constraint.target_value;
586 model.set_bias(var_i, current_bias - penalty_bias)?;
587 } else {
588 let current_coupling = model.get_coupling(var_i, var_j).unwrap_or(0.0);
590 let penalty_coupling = constraint.penalty_coefficient;
591 model.set_coupling(var_i, var_j, current_coupling + penalty_coupling)?;
592 }
593 }
594 }
595 }
596
597 Ok(())
598 }
599
600 #[must_use]
602 pub const fn get_statistics(&self) -> &ConstraintStatistics {
603 &self.statistics
604 }
605
606 pub fn reset_statistics(&mut self) {
608 self.statistics = ConstraintStatistics::default();
609 self.current_iteration = 0;
610 self.violation_history.clear();
611
612 for constraint in self.constraints.values_mut() {
614 constraint.violation_count = 0;
615 constraint.cumulative_violation = 0.0;
616 constraint.penalty_coefficient = constraint.initial_penalty;
617 }
618
619 self.update_constraint_counts();
620 }
621
622 fn update_constraint_counts(&mut self) {
624 self.statistics.total_constraints = self.constraints.len();
625 self.statistics.hard_constraints = self.constraints.values().filter(|c| !c.is_soft).count();
626 self.statistics.soft_constraints = self.constraints.values().filter(|c| c.is_soft).count();
627 }
628
629 fn update_statistics(&mut self, solution: &[i8]) {
631 let violations = self.evaluate_all(solution);
632
633 let current_violations = violations
634 .values()
635 .zip(self.constraints.values())
636 .filter(|(violation, constraint)| **violation > constraint.tolerance)
637 .count();
638
639 self.statistics.current_violations = current_violations;
640
641 self.statistics.total_violations =
642 self.constraints.values().map(|c| c.violation_count).sum();
643
644 self.statistics.avg_penalty_coefficient = self
645 .constraints
646 .values()
647 .map(|c| c.penalty_coefficient)
648 .sum::<f64>()
649 / self.constraints.len().max(1) as f64;
650
651 let max_violation = violations.values().copied().fold(0.0, f64::max);
652 self.statistics.max_violation = max_violation;
653
654 let avg_violation = violations.values().sum::<f64>() / violations.len().max(1) as f64;
655 self.statistics.avg_violation = avg_violation;
656
657 let satisfied_count = self.constraints.len() - current_violations;
658 self.statistics.satisfaction_rate =
659 satisfied_count as f64 / self.constraints.len().max(1) as f64;
660 }
661
662 #[must_use]
664 pub fn get_constraint_history(&self, constraint_id: &str) -> Vec<&ViolationRecord> {
665 self.violation_history
666 .iter()
667 .filter(|r| r.constraint_id == constraint_id)
668 .collect()
669 }
670
671 #[must_use]
673 pub fn get_most_violated_constraints(&self, top_k: usize) -> Vec<&Constraint> {
674 let mut constraints: Vec<&Constraint> = self.constraints.values().collect();
675 constraints.sort_by(|a, b| b.violation_count.cmp(&a.violation_count));
676 constraints.into_iter().take(top_k).collect()
677 }
678}
679
#[cfg(test)]
mod tests {
    use super::*;

    /// Every test here uses equality constraints, so they share one builder.
    fn equality(
        id: &str,
        variables: Vec<usize>,
        target: f64,
        priority: ConstraintPriority,
    ) -> Constraint {
        Constraint::new(id, ConstraintType::Equality, variables, target, priority)
    }

    /// A handler using the stock configuration.
    fn default_handler() -> AdaptiveConstraintHandler {
        AdaptiveConstraintHandler::new(AdaptiveConstraintConfig::default())
    }

    /// Fetches constraint `c1`, which the calling test must have registered.
    fn constraint_c1(handler: &AdaptiveConstraintHandler) -> &Constraint {
        handler
            .constraints
            .get("c1")
            .expect("constraint 'c1' should exist")
    }

    #[test]
    fn test_constraint_creation() {
        let c = equality("test_constraint", vec![0, 1, 2], 2.0, ConstraintPriority::High);

        assert_eq!(c.id, "test_constraint");
        assert_eq!(c.constraint_type, ConstraintType::Equality);
        assert_eq!(c.variables.len(), 3);
        assert_eq!(c.target_value, 2.0);
        assert!(c.is_soft);
    }

    #[test]
    fn test_constraint_evaluation() {
        let c = equality("sum_constraint", vec![0, 1, 2], 2.0, ConstraintPriority::Medium);
        let solution: [i8; 4] = [1, 1, 0, 0];

        // 1 + 1 + 0 == 2, so the equality holds exactly.
        assert_eq!(c.evaluate(&solution), 0.0);
    }

    #[test]
    fn test_constraint_violation_detection() {
        let c = equality("test", vec![0, 1], 1.0, ConstraintPriority::High);
        let satisfied: [i8; 4] = [1, 0, 0, 0];
        let violated: [i8; 4] = [1, 1, 0, 0];

        assert!(!c.is_violated(&satisfied));
        assert!(c.is_violated(&violated));
    }

    #[test]
    fn test_handler_creation() {
        let handler = default_handler();

        assert_eq!(handler.constraints.len(), 0);
        assert_eq!(handler.current_iteration, 0);
    }

    #[test]
    fn test_add_remove_constraints() {
        let mut handler = default_handler();

        handler.add_constraint(equality("c1", vec![0, 1], 1.0, ConstraintPriority::Medium));
        assert_eq!(handler.constraints.len(), 1);

        handler.remove_constraint("c1");
        assert_eq!(handler.constraints.len(), 0);
    }

    #[test]
    fn test_feasibility_check() {
        let mut handler = default_handler();
        handler.add_constraint(equality("c1", vec![0, 1], 1.0, ConstraintPriority::High));

        let feasible: [i8; 3] = [1, 0, 0];
        let infeasible: [i8; 3] = [1, 1, 0];

        assert!(handler.is_feasible(&feasible));
        assert!(!handler.is_feasible(&infeasible));
    }

    #[test]
    fn test_penalty_adaptation() {
        let mut handler = AdaptiveConstraintHandler::new(AdaptiveConstraintConfig {
            penalty_strategy: PenaltyStrategy::Multiplicative,
            penalty_increase_factor: 2.0,
            ..Default::default()
        });

        let c = equality("c1", vec![0, 1], 1.0, ConstraintPriority::Medium);
        let penalty_before = c.penalty_coefficient;
        handler.add_constraint(c);

        let violated: [i8; 3] = [1, 1, 0];
        handler.adapt_penalties(&violated);

        assert!(constraint_c1(&handler).penalty_coefficient > penalty_before);
    }

    #[test]
    fn test_total_penalty_calculation() {
        let mut handler = default_handler();
        for (id, vars) in [("c1", vec![0, 1]), ("c2", vec![2, 3])] {
            handler.add_constraint(equality(id, vars, 1.0, ConstraintPriority::Medium));
        }

        // Both pairs sum to 2 against a target of 1, so both are penalised.
        let solution: [i8; 4] = [1, 1, 1, 1];
        assert!(handler.total_penalty(&solution) > 0.0);
    }

    #[test]
    fn test_statistics_tracking() {
        let mut handler = default_handler();
        handler.add_constraint(equality("c1", vec![0, 1], 1.0, ConstraintPriority::High));

        let stats = handler.get_statistics();
        assert_eq!(stats.total_constraints, 1);
        assert_eq!(stats.soft_constraints, 1);
    }

    #[test]
    fn test_constraint_relaxation() {
        let mut handler = AdaptiveConstraintHandler::new(AdaptiveConstraintConfig {
            relaxation_strategy: RelaxationStrategy::Linear,
            relaxation_rate: 0.1,
            ..Default::default()
        });

        let c = equality("c1", vec![0, 1], 1.0, ConstraintPriority::Low);
        let tolerance_before = c.tolerance;
        handler.add_constraint(c);

        handler.apply_relaxation();

        assert!(constraint_c1(&handler).tolerance > tolerance_before);
    }

    #[test]
    fn test_violation_history() {
        let mut handler = default_handler();
        handler.add_constraint(equality("c1", vec![0, 1], 1.0, ConstraintPriority::Medium));

        let violated: [i8; 3] = [1, 1, 0];
        handler.adapt_penalties(&violated);

        assert!(!handler.violation_history.is_empty());
        assert!(!handler.get_constraint_history("c1").is_empty());
    }
}