use crate::{
    optimization::penalty::CompiledModel,
    sampler::{SampleResult, Sampler},
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

#[cfg(feature = "scirs")]
use crate::scirs_stub::scirs2_core::statistics::{MovingAverage, OnlineStats};

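/// Strategy used to adapt sampler parameters and penalty weights between iterations.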
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum AdaptiveStrategy {
    /// Decay the temperature geometrically while gradually increasing penalty weights.
    ExponentialDecay,
    /// Scale penalty weights up or down based on observed constraint violations.
    AdaptivePenaltyMethod,
    /// Maintain Lagrange multipliers that are updated from constraint violations.
    AugmentedLagrangian,
    /// Evolve a population of parameter sets and blend in the best performer.
    PopulationBased,
    /// Epsilon-greedy choice between random exploration and the best known parameters.
    MultiArmedBandit,
}

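/// Configuration for the adaptive optimizer.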
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdaptiveConfig {
    /// Adaptation strategy to apply.
    pub strategy: AdaptiveStrategy,
    /// Number of iterations between parameter/penalty updates.
    pub update_interval: usize,
    /// Step size used when adjusting parameters and penalty weights.
    pub learning_rate: f64,
    /// Momentum applied when blending in parameters from the best individual.
    pub momentum: f64,
    /// Iterations without improvement before stopping early.
    pub patience: usize,
    /// Probability of random exploration in the multi-armed bandit strategy.
    pub exploration_rate: f64,
    /// Number of individuals in the population-based strategy.
    pub population_size: usize,
    /// Length of the history window.
    pub history_window: usize,
}

impl Default for AdaptiveConfig {
    fn default() -> Self {
        Self {
            strategy: AdaptiveStrategy::AdaptivePenaltyMethod,
            update_interval: 10,
            learning_rate: 0.1,
            momentum: 0.9,
            patience: 5,
            exploration_rate: 0.1,
            population_size: 10,
            history_window: 100,
        }
    }
}

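/// Optimizer that adaptively tunes sampler parameters and penalty weights while solving.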
pub struct AdaptiveOptimizer {
    config: AdaptiveConfig,
    iteration: usize,
    parameter_history: Vec<ParameterState>,
    performance_history: Vec<PerformanceMetrics>,
    lagrange_multipliers: HashMap<String, f64>,
    population: Vec<Individual>,
    #[cfg(feature = "scirs")]
    stats: OnlineStats,
}

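/// Snapshot of the parameters and penalty weights used at a given iteration.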
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParameterState {
    pub iteration: usize,
    pub parameters: HashMap<String, f64>,
    pub penalty_weights: HashMap<String, f64>,
    pub temperature: f64,
}

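/// Performance statistics collected from the samples of a single iteration.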
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PerformanceMetrics {
    pub iteration: usize,
    pub best_energy: f64,
    pub avg_energy: f64,
    pub constraint_violations: HashMap<String, f64>,
    pub feasibility_rate: f64,
    pub diversity: f64,
}

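/// A candidate parameter set used by the population-based strategy.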
#[derive(Debug, Clone)]
struct Individual {
    id: usize,
    parameters: HashMap<String, f64>,
    fitness: f64,
    constraint_satisfaction: f64,
}

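/// Result of an adaptive optimization run.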
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdaptiveResult {
    pub final_parameters: HashMap<String, f64>,
    pub final_penalty_weights: HashMap<String, f64>,
    pub convergence_history: Vec<f64>,
    pub constraint_history: Vec<HashMap<String, f64>>,
    pub total_iterations: usize,
    pub best_solution: AdaptiveSampleResult,
}

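/// Best sample found during optimization, stored in a serializable form.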
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdaptiveSampleResult {
    pub assignments: HashMap<String, bool>,
    pub energy: f64,
}

impl AdaptiveOptimizer {
    /// Create a new adaptive optimizer with the given configuration.
    pub fn new(config: AdaptiveConfig) -> Self {
        Self {
            config,
            iteration: 0,
            parameter_history: Vec::new(),
            performance_history: Vec::new(),
            lagrange_multipliers: HashMap::new(),
            population: Vec::new(),
            #[cfg(feature = "scirs")]
            stats: OnlineStats::new(),
        }
    }

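    /// Run the adaptive optimization loop: sample with the current parameters,
    /// evaluate performance, and periodically update parameters and penalty
    /// weights according to the configured strategy, stopping early when the
    /// best energy has not improved for more than `patience` iterations.
    ///
    /// Minimal usage sketch (illustrative only: `MySampler`, `compiled_model`,
    /// and the map keys below are hypothetical stand-ins):
    ///
    /// ```ignore
    /// let mut optimizer = AdaptiveOptimizer::new(AdaptiveConfig::default());
    /// let result = optimizer.optimize(
    ///     MySampler::new(),
    ///     &compiled_model,
    ///     HashMap::from([("temperature".to_string(), 10.0)]),
    ///     HashMap::from([("constraint_1".to_string(), 1.0)]),
    ///     100,
    /// )?;
    /// println!("best energy = {}", result.best_solution.energy);
    /// ```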
    pub fn optimize<S: Sampler + Clone>(
        &mut self,
        mut sampler: S,
        model: &CompiledModel,
        initial_params: HashMap<String, f64>,
        initial_penalties: HashMap<String, f64>,
        max_iterations: usize,
    ) -> Result<AdaptiveResult, Box<dyn std::error::Error>> {
        let mut current_params = initial_params;
        let mut penalty_weights = initial_penalties;
        let mut best_solution = None;
        let mut best_energy = f64::INFINITY;

        // Strategy-specific initialization.
        match self.config.strategy {
            AdaptiveStrategy::PopulationBased => {
                self.initialize_population(&current_params)?;
            }
            AdaptiveStrategy::AugmentedLagrangian => {
                self.initialize_lagrange_multipliers(&penalty_weights);
            }
            _ => {}
        }

        let mut no_improvement_count = 0;

        for iter in 0..max_iterations {
            self.iteration = iter;

            // Sample with the current parameters and penalty weights.
            let samples =
                self.run_sampling(&mut sampler, model, &current_params, &penalty_weights)?;

            // Evaluate this iteration's performance.
            let metrics = self.evaluate_performance(model, &samples)?;
            self.performance_history.push(metrics.clone());

            // Track the best solution seen so far.
            if let Some(sample) = samples.iter().min_by(|a, b| {
                a.energy
                    .partial_cmp(&b.energy)
                    .unwrap_or(std::cmp::Ordering::Equal)
            }) {
                if sample.energy < best_energy {
                    best_energy = sample.energy;
                    best_solution = Some(AdaptiveSampleResult {
                        assignments: sample.assignments.clone(),
                        energy: sample.energy,
                    });
                    no_improvement_count = 0;
                } else {
                    no_improvement_count += 1;
                }
            }

            // Early stopping when the best energy has stalled.
            if no_improvement_count > self.config.patience {
                break;
            }

            // Periodically adapt parameters and penalty weights.
            if iter % self.config.update_interval == 0 && iter > 0 {
                self.update_parameters(&mut current_params, &mut penalty_weights, &metrics)?;
            }

            // Record the parameter state for this iteration.
            self.parameter_history.push(ParameterState {
                iteration: iter,
                parameters: current_params.clone(),
                penalty_weights: penalty_weights.clone(),
                temperature: self.calculate_temperature(iter, max_iterations),
            });
        }

        let convergence_history = self
            .performance_history
            .iter()
            .map(|m| m.best_energy)
            .collect();

        let constraint_history = self
            .performance_history
            .iter()
            .map(|m| m.constraint_violations.clone())
            .collect();

        Ok(AdaptiveResult {
            final_parameters: current_params,
            final_penalty_weights: penalty_weights,
            convergence_history,
            constraint_history,
            total_iterations: self.iteration,
            best_solution: best_solution.ok_or("No valid solution found")?,
        })
    }

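    /// Apply penalties to the model, forward parameters to the sampler, and collect samples.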
    fn run_sampling<S: Sampler>(
        &self,
        sampler: &mut S,
        model: &CompiledModel,
        params: &HashMap<String, f64>,
        penalty_weights: &HashMap<String, f64>,
    ) -> Result<Vec<SampleResult>, Box<dyn std::error::Error>> {
        // Build the model for this iteration (see `apply_penalties`).
        let penalized_model = self.apply_penalties(model, penalty_weights)?;

        // Forward the current tuning parameters to the sampler.
        sampler.set_parameters(params.clone());

        let num_reads = params.get("num_reads").copied().unwrap_or(100.0) as usize;

        Ok(sampler.run_qubo(&penalized_model.to_qubo(), num_reads)?)
    }

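    /// Compute performance metrics (energies, feasibility, diversity) for a batch of samples.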
    fn evaluate_performance(
        &self,
        model: &CompiledModel,
        samples: &[SampleResult],
    ) -> Result<PerformanceMetrics, Box<dyn std::error::Error>> {
        let energies: Vec<f64> = samples.iter().map(|s| s.energy).collect();
        let best_energy = energies.iter().fold(f64::INFINITY, |a, &b| a.min(b));
        let avg_energy = energies.iter().sum::<f64>() / energies.len() as f64;

        let constraint_violations = self.evaluate_constraint_violations(model, samples)?;

        // Fraction of samples that satisfy all constraints.
        let feasible_count = samples
            .iter()
            .filter(|s| self.is_feasible(s, &constraint_violations).unwrap_or(false))
            .count();
        let feasibility_rate = feasible_count as f64 / samples.len() as f64;

        // Average pairwise Hamming distance between samples.
        let diversity = self.calculate_diversity(samples);

        Ok(PerformanceMetrics {
            iteration: self.iteration,
            best_energy,
            avg_energy,
            constraint_violations,
            feasibility_rate,
            diversity,
        })
    }

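    /// Dispatch the parameter/penalty update to the configured strategy.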
    fn update_parameters(
        &mut self,
        params: &mut HashMap<String, f64>,
        penalty_weights: &mut HashMap<String, f64>,
        metrics: &PerformanceMetrics,
    ) -> Result<(), Box<dyn std::error::Error>> {
        match self.config.strategy {
            AdaptiveStrategy::ExponentialDecay => {
                self.update_exponential_decay(params, penalty_weights)?;
            }
            AdaptiveStrategy::AdaptivePenaltyMethod => {
                self.update_adaptive_penalty(penalty_weights, metrics)?;
            }
            AdaptiveStrategy::AugmentedLagrangian => {
                self.update_augmented_lagrangian(penalty_weights, metrics)?;
            }
            AdaptiveStrategy::PopulationBased => {
                self.update_population_based(params, penalty_weights, metrics)?;
            }
            AdaptiveStrategy::MultiArmedBandit => {
                self.update_multi_armed_bandit(params, metrics)?;
            }
        }

        Ok(())
    }

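    /// Exponential-decay strategy: cool the temperature and slowly increase penalty weights.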
    fn update_exponential_decay(
        &self,
        params: &mut HashMap<String, f64>,
        penalty_weights: &mut HashMap<String, f64>,
    ) -> Result<(), Box<dyn std::error::Error>> {
        let decay_rate = 0.95;

        // Cool the temperature geometrically.
        if let Some(temp) = params.get_mut("temperature") {
            *temp *= decay_rate;
        }

        // Gradually strengthen the penalty weights as the temperature drops.
        for weight in penalty_weights.values_mut() {
            *weight *= 1.0 / decay_rate.sqrt();
        }

        Ok(())
    }

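    /// Adaptive penalty strategy: scale each penalty weight according to its observed violation.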
    fn update_adaptive_penalty(
        &mut self,
        penalty_weights: &mut HashMap<String, f64>,
        metrics: &PerformanceMetrics,
    ) -> Result<(), Box<dyn std::error::Error>> {
        for (constraint_name, &violation) in &metrics.constraint_violations {
            if let Some(weight) = penalty_weights.get_mut(constraint_name) {
                if violation > 1e-6 {
                    // Constraint is violated: strengthen its penalty.
                    *weight *= 1.0 + self.config.learning_rate;
                } else {
                    // Constraint is satisfied: relax the penalty slightly.
                    *weight *= self.config.learning_rate.mul_add(-0.5, 1.0);
                }

                // Keep weights within a sane range.
                *weight = weight.clamp(0.001, 1000.0);
            }
        }

        #[cfg(feature = "scirs")]
        {
            self.stats.update(metrics.best_energy);
        }

        Ok(())
    }

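    /// Augmented Lagrangian strategy: update multipliers from violations and derive new weights.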
    fn update_augmented_lagrangian(
        &mut self,
        penalty_weights: &mut HashMap<String, f64>,
        metrics: &PerformanceMetrics,
    ) -> Result<(), Box<dyn std::error::Error>> {
        for (constraint_name, &violation) in &metrics.constraint_violations {
            let multiplier = self
                .lagrange_multipliers
                .entry(constraint_name.clone())
                .or_insert(0.0);

            // Gradient-style update of the multiplier from the observed violation.
            *multiplier += self.config.learning_rate * violation;

            // Combine the multiplier with the current weight.
            if let Some(weight) = penalty_weights.get_mut(constraint_name) {
                *weight = 0.5f64.mul_add(weight.sqrt(), multiplier.abs());
            }
        }

        Ok(())
    }

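    /// Population-based strategy: rank individuals, exploit the best one, and perturb the worst half.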
    fn update_population_based(
        &mut self,
        params: &mut HashMap<String, f64>,
        _penalty_weights: &mut HashMap<String, f64>,
        metrics: &PerformanceMetrics,
    ) -> Result<(), Box<dyn std::error::Error>> {
        // Score every individual against the latest metrics.
        let fitness_values: Vec<f64> = self
            .population
            .iter()
            .map(|individual| self.evaluate_individual_fitness(individual, metrics))
            .collect::<Result<Vec<_>, _>>()?;

        for (i, fitness) in fitness_values.into_iter().enumerate() {
            self.population[i].fitness = fitness;
        }

        // Sort by fitness, best first.
        self.population.sort_by(|a, b| {
            b.fitness
                .partial_cmp(&a.fitness)
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        // Blend the best individual's parameters into the current ones.
        if let Some(best) = self.population.first() {
            for (key, value) in &best.parameters {
                if let Some(param) = params.get_mut(key) {
                    *param = self
                        .config
                        .momentum
                        .mul_add(*param, (1.0 - self.config.momentum) * value);
                }
            }
        }

        // Randomly perturb the worse half of the population.
        use scirs2_core::random::prelude::*;
        let mut rng = thread_rng();

        let pop_len = self.population.len();
        let mid = pop_len / 2;
        for i in mid..pop_len {
            for value in self.population[i].parameters.values_mut() {
                if rng.gen::<f64>() < 0.3 {
                    let perturbation = rng.gen_range(-0.3..0.3) * value.abs();
                    *value += perturbation;
                }
            }
        }

        Ok(())
    }

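    /// Multi-armed-bandit strategy: explore randomly or move towards the best recorded parameters.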
    fn update_multi_armed_bandit(
        &mut self,
        params: &mut HashMap<String, f64>,
        _metrics: &PerformanceMetrics,
    ) -> Result<(), Box<dyn std::error::Error>> {
        use scirs2_core::random::prelude::*;
        let mut rng = thread_rng();

        for (param_name, param_value) in params.iter_mut() {
            if rng.gen::<f64>() < self.config.exploration_rate {
                // Explore: apply a small random perturbation.
                let perturbation = rng.gen_range(-0.1..0.1) * param_value.abs();
                *param_value += perturbation;
            } else {
                // Exploit: move towards the parameter value from the best recorded iteration.
                if let Some(best_state) = self.parameter_history.iter().min_by(|a, b| {
                    let a_metrics = &self.performance_history[a.iteration];
                    let b_metrics = &self.performance_history[b.iteration];
                    a_metrics
                        .best_energy
                        .partial_cmp(&b_metrics.best_energy)
                        .unwrap_or(std::cmp::Ordering::Equal)
                }) {
                    if let Some(best_value) = best_state.parameters.get(param_name) {
                        *param_value += self.config.learning_rate * (best_value - *param_value);
                    }
                }
            }
        }

        Ok(())
    }

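    /// Create the initial population by randomly perturbing the base parameters.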
    fn initialize_population(
        &mut self,
        base_params: &HashMap<String, f64>,
    ) -> Result<(), Box<dyn std::error::Error>> {
        use scirs2_core::random::prelude::*;
        let mut rng = thread_rng();

        for i in 0..self.config.population_size {
            let mut params = base_params.clone();

            // Randomly perturb the base parameters for each individual.
            for value in params.values_mut() {
                let perturbation = rng.gen_range(-0.2..0.2) * value.abs();
                *value += perturbation;
            }

            self.population.push(Individual {
                id: i,
                parameters: params,
                fitness: 0.0,
                constraint_satisfaction: 0.0,
            });
        }

        Ok(())
    }

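    /// Seed the Lagrange multipliers from the initial penalty weights.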
    fn initialize_lagrange_multipliers(&mut self, penalty_weights: &HashMap<String, f64>) {
        for (constraint_name, &weight) in penalty_weights {
            self.lagrange_multipliers
                .insert(constraint_name.clone(), weight * 0.1);
        }
    }

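    /// Apply penalty weights to the compiled model (currently a pass-through placeholder).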
    fn apply_penalties(
        &self,
        model: &CompiledModel,
        _penalty_weights: &HashMap<String, f64>,
    ) -> Result<CompiledModel, Box<dyn std::error::Error>> {
        // Placeholder: the penalty weights are not yet folded into the model.
        Ok(model.clone())
    }

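    /// Evaluate constraint violations for the samples (currently a placeholder returning zeros).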
    fn evaluate_constraint_violations(
        &self,
        model: &CompiledModel,
        _samples: &[SampleResult],
    ) -> Result<HashMap<String, f64>, Box<dyn std::error::Error>> {
        let mut violations = HashMap::new();

        // Placeholder: report zero violation for every constraint in the model.
        for constraint_name in model.get_constraints().keys() {
            violations.insert(constraint_name.clone(), 0.0);
        }

        Ok(violations)
    }

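    /// A sample is considered feasible when the largest constraint violation is below tolerance.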
    fn is_feasible(
        &self,
        _sample: &SampleResult,
        constraint_violations: &HashMap<String, f64>,
    ) -> Result<bool, Box<dyn std::error::Error>> {
        let max_violation = constraint_violations
            .values()
            .fold(0.0f64, |a, &b| a.max(b.abs()));

        Ok(max_violation < 1e-6)
    }

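    /// Diversity of a sample set, measured as the average pairwise Hamming distance.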
    fn calculate_diversity(&self, samples: &[SampleResult]) -> f64 {
        if samples.len() < 2 {
            return 0.0;
        }

        // Sum the Hamming distance over all unordered sample pairs.
        let mut total_distance = 0.0;
        let mut count = 0;

        for i in 0..samples.len() {
            for j in i + 1..samples.len() {
                let distance = self.hamming_distance(&samples[i], &samples[j]);
                total_distance += distance as f64;
                count += 1;
            }
        }

        if count > 0 {
            total_distance / count as f64
        } else {
            0.0
        }
    }

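    /// Number of variables assigned differently in two samples.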
    fn hamming_distance(&self, a: &SampleResult, b: &SampleResult) -> usize {
        a.assignments
            .iter()
            .filter(|(var, &val_a)| b.assignments.get(*var).copied().unwrap_or(false) != val_a)
            .count()
    }

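    /// Geometric cooling schedule from the initial to the final temperature.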
    fn calculate_temperature(&self, iteration: usize, max_iterations: usize) -> f64 {
        let progress = iteration as f64 / max_iterations as f64;
        let initial_temp = 10.0f64;
        let final_temp = 0.01f64;

        initial_temp * (final_temp / initial_temp).powf(progress)
    }

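    /// Fitness of an individual: a weighted mix of objective quality and feasibility rate.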
    fn evaluate_individual_fitness(
        &self,
        _individual: &Individual,
        metrics: &PerformanceMetrics,
    ) -> Result<f64, Box<dyn std::error::Error>> {
        // Weighted combination of objective quality and feasibility.
        let objective_score = 1.0 / (1.0 + metrics.best_energy.abs());
        let constraint_score = metrics.feasibility_rate;

        Ok(0.7f64.mul_add(objective_score, 0.3 * constraint_score))
    }

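    /// Randomly perturb an individual's parameters (helper for the population-based strategy).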
    fn perturb_individual(
        &self,
        individual: &mut Individual,
    ) -> Result<(), Box<dyn std::error::Error>> {
        use scirs2_core::random::prelude::*;
        let mut rng = thread_rng();

        for value in individual.parameters.values_mut() {
            if rng.gen::<f64>() < 0.3 {
                let perturbation = rng.gen_range(-0.3..0.3) * value.abs();
                *value += perturbation;
            }
        }

        Ok(())
    }

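    /// Serialize the configuration and recorded history to a JSON file at `path`.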
    pub fn export_history(&self, path: &str) -> Result<(), Box<dyn std::error::Error>> {
        let export = AdaptiveExport {
            config: self.config.clone(),
            parameter_history: self.parameter_history.clone(),
            performance_history: self.performance_history.clone(),
            timestamp: std::time::SystemTime::now(),
        };

        let json = serde_json::to_string_pretty(&export)?;
        std::fs::write(path, json)?;

        Ok(())
    }
}

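/// Serializable export of the optimizer configuration and recorded history.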
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdaptiveExport {
    pub config: AdaptiveConfig,
    pub parameter_history: Vec<ParameterState>,
    pub performance_history: Vec<PerformanceMetrics>,
    pub timestamp: std::time::SystemTime,
}

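/// Internal extension trait used to forward tuning parameters to a sampler.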
trait SamplerExt {
    fn set_parameters(&mut self, params: HashMap<String, f64>);
}

impl<S: Sampler> SamplerExt for S {
    fn set_parameters(&mut self, _params: HashMap<String, f64>) {
        // No-op: with this blanket impl, sampler parameters are currently ignored.
    }
}