// quantrs2_tytan/optimization/adaptive.rs

//! Adaptive optimization strategies for quantum annealing
//!
//! This module provides adaptive algorithms that adjust parameters
//! during optimization based on performance feedback.

6use crate::{
7    optimization::penalty::CompiledModel,
8    sampler::{SampleResult, Sampler},
9};
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12
13#[cfg(feature = "scirs")]
14use crate::scirs_stub::scirs2_core::statistics::{MovingAverage, OnlineStats};
15
16/// Adaptive strategy types
17#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
18pub enum AdaptiveStrategy {
19    /// Exponential decay of penalty weights
20    ExponentialDecay,
21    /// Adaptive penalty method (APM)
22    AdaptivePenaltyMethod,
23    /// Augmented Lagrangian with multiplier updates
24    AugmentedLagrangian,
25    /// Population-based training
26    PopulationBased,
27    /// Multi-armed bandit for parameter selection
28    MultiArmedBandit,
29}
30
31/// Adaptive optimizer configuration
32#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct AdaptiveConfig {
34    pub strategy: AdaptiveStrategy,
35    pub update_interval: usize,
36    pub learning_rate: f64,
37    pub momentum: f64,
38    pub patience: usize,
39    pub exploration_rate: f64,
40    pub population_size: usize,
41    pub history_window: usize,
42}
43
44impl Default for AdaptiveConfig {
45    fn default() -> Self {
46        Self {
47            strategy: AdaptiveStrategy::AdaptivePenaltyMethod,
48            update_interval: 10,
49            learning_rate: 0.1,
50            momentum: 0.9,
51            patience: 5,
52            exploration_rate: 0.1,
53            population_size: 10,
54            history_window: 100,
55        }
56    }
57}
58
59/// Adaptive optimizer
60pub struct AdaptiveOptimizer {
61    config: AdaptiveConfig,
62    iteration: usize,
63    parameter_history: Vec<ParameterState>,
64    performance_history: Vec<PerformanceMetrics>,
65    lagrange_multipliers: HashMap<String, f64>,
66    population: Vec<Individual>,
67    #[cfg(feature = "scirs")]
68    stats: OnlineStats,
69}
70
71/// Parameter state at a given iteration
72#[derive(Debug, Clone, Serialize, Deserialize)]
73pub struct ParameterState {
74    pub iteration: usize,
75    pub parameters: HashMap<String, f64>,
76    pub penalty_weights: HashMap<String, f64>,
77    pub temperature: f64,
78}
79
80/// Performance metrics
81#[derive(Debug, Clone, Serialize, Deserialize)]
82pub struct PerformanceMetrics {
83    pub iteration: usize,
84    pub best_energy: f64,
85    pub avg_energy: f64,
86    pub constraint_violations: HashMap<String, f64>,
87    pub feasibility_rate: f64,
88    pub diversity: f64,
89}
90
/// Member of the population used by the population-based strategy.
#[derive(Clone, Debug)]
struct Individual {
    /// Position of the individual at creation time.
    id: usize,
    /// This individual's copy of the sampler parameters.
    parameters: HashMap<String, f64>,
    /// Fitness score; higher is better.
    fitness: f64,
    /// Constraint-satisfaction score (initialised to 0.0).
    constraint_satisfaction: f64,
}

100/// Adaptive optimization result
101#[derive(Debug, Clone, Serialize, Deserialize)]
102pub struct AdaptiveResult {
103    pub final_parameters: HashMap<String, f64>,
104    pub final_penalty_weights: HashMap<String, f64>,
105    pub convergence_history: Vec<f64>,
106    pub constraint_history: Vec<HashMap<String, f64>>,
107    pub total_iterations: usize,
108    pub best_solution: AdaptiveSampleResult,
109}
110
111/// Sample result wrapper for serialization
112#[derive(Debug, Clone, Serialize, Deserialize)]
113pub struct AdaptiveSampleResult {
114    pub assignments: HashMap<String, bool>,
115    pub energy: f64,
116}
117
118impl AdaptiveOptimizer {
119    /// Create new adaptive optimizer
120    pub fn new(config: AdaptiveConfig) -> Self {
121        Self {
122            config,
123            iteration: 0,
124            parameter_history: Vec::new(),
125            performance_history: Vec::new(),
126            lagrange_multipliers: HashMap::new(),
127            population: Vec::new(),
128            #[cfg(feature = "scirs")]
129            stats: OnlineStats::new(),
130        }
131    }
132
133    /// Run adaptive optimization
134    pub fn optimize<S: Sampler + Clone>(
135        &mut self,
136        mut sampler: S,
137        model: &CompiledModel,
138        initial_params: HashMap<String, f64>,
139        initial_penalties: HashMap<String, f64>,
140        max_iterations: usize,
141    ) -> Result<AdaptiveResult, Box<dyn std::error::Error>> {
142        // Initialize
143        let mut current_params = initial_params;
144        let mut penalty_weights = initial_penalties;
145        let mut best_solution = None;
146        let mut best_energy = f64::INFINITY;
147
148        // Initialize strategy-specific components
149        match self.config.strategy {
150            AdaptiveStrategy::PopulationBased => {
151                self.initialize_population(&current_params)?;
152            }
153            AdaptiveStrategy::AugmentedLagrangian => {
154                self.initialize_lagrange_multipliers(&penalty_weights);
155            }
156            _ => {}
157        }
158
159        // Main optimization loop
160        let mut no_improvement_count = 0;
161
162        for iter in 0..max_iterations {
163            self.iteration = iter;
164
165            // Run sampling with current parameters
166            let samples =
167                self.run_sampling(&mut sampler, model, &current_params, &penalty_weights)?;
168
169            // Evaluate performance
170            let metrics = self.evaluate_performance(model, &samples)?;
171            self.performance_history.push(metrics.clone());
172
173            // Update best solution
174            if let Some(sample) = samples.iter().min_by(|a, b| {
175                a.energy
176                    .partial_cmp(&b.energy)
177                    .unwrap_or(std::cmp::Ordering::Equal)
178            }) {
179                if sample.energy < best_energy {
180                    best_energy = sample.energy;
181                    best_solution = Some(AdaptiveSampleResult {
182                        assignments: sample.assignments.clone(),
183                        energy: sample.energy,
184                    });
185                    no_improvement_count = 0;
186                } else {
187                    no_improvement_count += 1;
188                }
189            }
190
191            // Check early stopping
192            if no_improvement_count > self.config.patience {
193                break;
194            }
195
196            // Update parameters based on strategy
197            if iter % self.config.update_interval == 0 && iter > 0 {
198                self.update_parameters(&mut current_params, &mut penalty_weights, &metrics)?;
199            }
200
201            // Record state
202            self.parameter_history.push(ParameterState {
203                iteration: iter,
204                parameters: current_params.clone(),
205                penalty_weights: penalty_weights.clone(),
206                temperature: self.calculate_temperature(iter, max_iterations),
207            });
208        }
209
210        // Prepare result
211        let convergence_history = self
212            .performance_history
213            .iter()
214            .map(|m| m.best_energy)
215            .collect();
216
217        let constraint_history = self
218            .performance_history
219            .iter()
220            .map(|m| m.constraint_violations.clone())
221            .collect();
222
223        Ok(AdaptiveResult {
224            final_parameters: current_params,
225            final_penalty_weights: penalty_weights,
226            convergence_history,
227            constraint_history,
228            total_iterations: self.iteration,
229            best_solution: best_solution.ok_or("No valid solution found")?,
230        })
231    }
232
233    /// Run sampling with current parameters
234    fn run_sampling<S: Sampler>(
235        &self,
236        sampler: &mut S,
237        model: &CompiledModel,
238        params: &HashMap<String, f64>,
239        penalty_weights: &HashMap<String, f64>,
240    ) -> Result<Vec<SampleResult>, Box<dyn std::error::Error>> {
241        // Apply penalty weights to model
242        let penalized_model = self.apply_penalties(model, penalty_weights)?;
243
244        // Configure sampler with parameters
245        sampler.set_parameters(params.clone());
246
247        // Run sampling
248        let num_reads = params.get("num_reads").copied().unwrap_or(100.0) as usize;
249
250        Ok(sampler.run_qubo(&penalized_model.to_qubo(), num_reads)?)
251    }
252
253    /// Evaluate performance metrics
254    fn evaluate_performance(
255        &self,
256        model: &CompiledModel,
257        samples: &[SampleResult],
258    ) -> Result<PerformanceMetrics, Box<dyn std::error::Error>> {
259        let energies: Vec<f64> = samples.iter().map(|s| s.energy).collect();
260        let best_energy = energies.iter().fold(f64::INFINITY, |a, &b| a.min(b));
261        let avg_energy = energies.iter().sum::<f64>() / energies.len() as f64;
262
263        // Evaluate constraint violations
264        let constraint_violations = self.evaluate_constraint_violations(model, samples)?;
265
266        // Calculate feasibility rate
267        let feasible_count = samples
268            .iter()
269            .filter(|s| self.is_feasible(s, &constraint_violations).unwrap_or(false))
270            .count();
271        let feasibility_rate = feasible_count as f64 / samples.len() as f64;
272
273        // Calculate diversity
274        let diversity = self.calculate_diversity(samples);
275
276        Ok(PerformanceMetrics {
277            iteration: self.iteration,
278            best_energy,
279            avg_energy,
280            constraint_violations,
281            feasibility_rate,
282            diversity,
283        })
284    }
285
286    /// Update parameters based on adaptive strategy
287    fn update_parameters(
288        &mut self,
289        params: &mut HashMap<String, f64>,
290        penalty_weights: &mut HashMap<String, f64>,
291        metrics: &PerformanceMetrics,
292    ) -> Result<(), Box<dyn std::error::Error>> {
293        match self.config.strategy {
294            AdaptiveStrategy::ExponentialDecay => {
295                self.update_exponential_decay(params, penalty_weights)?;
296            }
297            AdaptiveStrategy::AdaptivePenaltyMethod => {
298                self.update_adaptive_penalty(penalty_weights, metrics)?;
299            }
300            AdaptiveStrategy::AugmentedLagrangian => {
301                self.update_augmented_lagrangian(penalty_weights, metrics)?;
302            }
303            AdaptiveStrategy::PopulationBased => {
304                self.update_population_based(params, penalty_weights, metrics)?;
305            }
306            AdaptiveStrategy::MultiArmedBandit => {
307                self.update_multi_armed_bandit(params, metrics)?;
308            }
309        }
310
311        Ok(())
312    }
313
314    /// Exponential decay update
315    fn update_exponential_decay(
316        &self,
317        params: &mut HashMap<String, f64>,
318        penalty_weights: &mut HashMap<String, f64>,
319    ) -> Result<(), Box<dyn std::error::Error>> {
320        let decay_rate = 0.95;
321
322        // Decay temperature parameter
323        if let Some(temp) = params.get_mut("temperature") {
324            *temp *= decay_rate;
325        }
326
327        // Optionally adjust penalty weights
328        for weight in penalty_weights.values_mut() {
329            *weight *= 1.0 / decay_rate.sqrt(); // Increase penalties as temperature decreases
330        }
331
332        Ok(())
333    }
334
335    /// Adaptive penalty method update
336    fn update_adaptive_penalty(
337        &mut self,
338        penalty_weights: &mut HashMap<String, f64>,
339        metrics: &PerformanceMetrics,
340    ) -> Result<(), Box<dyn std::error::Error>> {
341        // Update penalties based on constraint violations
342        for (constraint_name, &violation) in &metrics.constraint_violations {
343            if let Some(weight) = penalty_weights.get_mut(constraint_name) {
344                if violation > 1e-6 {
345                    // Increase penalty
346                    *weight *= 1.0 + self.config.learning_rate;
347                } else {
348                    // Decrease penalty if over-penalized
349                    *weight *= self.config.learning_rate.mul_add(-0.5, 1.0);
350                }
351
352                // Apply bounds
353                *weight = weight.clamp(0.001, 1000.0);
354            }
355        }
356
357        #[cfg(feature = "scirs")]
358        {
359            // Update statistics
360            self.stats.update(metrics.best_energy);
361        }
362
363        Ok(())
364    }
365
366    /// Augmented Lagrangian update
367    fn update_augmented_lagrangian(
368        &mut self,
369        penalty_weights: &mut HashMap<String, f64>,
370        metrics: &PerformanceMetrics,
371    ) -> Result<(), Box<dyn std::error::Error>> {
372        // Update Lagrange multipliers
373        for (constraint_name, &violation) in &metrics.constraint_violations {
374            let multiplier = self
375                .lagrange_multipliers
376                .entry(constraint_name.clone())
377                .or_insert(0.0);
378
379            // Gradient ascent on multipliers
380            *multiplier += self.config.learning_rate * violation;
381
382            // Update penalty weight (augmented term)
383            if let Some(weight) = penalty_weights.get_mut(constraint_name) {
384                *weight = 0.5f64.mul_add(weight.sqrt(), multiplier.abs());
385            }
386        }
387
388        Ok(())
389    }
390
391    /// Population-based update
392    fn update_population_based(
393        &mut self,
394        params: &mut HashMap<String, f64>,
395        _penalty_weights: &mut HashMap<String, f64>,
396        metrics: &PerformanceMetrics,
397    ) -> Result<(), Box<dyn std::error::Error>> {
398        // Evaluate population fitness
399        let fitness_values: Vec<f64> = self
400            .population
401            .iter()
402            .map(|individual| self.evaluate_individual_fitness(individual, metrics))
403            .collect::<Result<Vec<_>, _>>()?;
404
405        for (i, fitness) in fitness_values.into_iter().enumerate() {
406            self.population[i].fitness = fitness;
407        }
408
409        // Sort by fitness
410        self.population.sort_by(|a, b| {
411            b.fitness
412                .partial_cmp(&a.fitness)
413                .unwrap_or(std::cmp::Ordering::Equal)
414        });
415
416        // Exploit: copy parameters from best individuals
417        if let Some(best) = self.population.first() {
418            for (key, value) in &best.parameters {
419                if let Some(param) = params.get_mut(key) {
420                    *param = self
421                        .config
422                        .momentum
423                        .mul_add(*param, (1.0 - self.config.momentum) * value);
424                }
425            }
426        }
427
428        // Explore: perturb bottom half of population
429        let mid = self.population.len() / 2;
430        let pop_len = self.population.len();
431        for i in mid..pop_len {
432            // Use random perturbation directly to avoid borrow issues
433            use scirs2_core::random::prelude::*;
434            let mut rng = thread_rng();
435
436            for value in self.population[i].parameters.values_mut() {
437                if rng.gen::<f64>() < 0.3 {
438                    let perturbation = rng.gen_range(-0.3..0.3) * value.abs();
439                    *value += perturbation;
440                }
441            }
442        }
443
444        Ok(())
445    }
446
447    /// Multi-armed bandit update
448    fn update_multi_armed_bandit(
449        &mut self,
450        params: &mut HashMap<String, f64>,
451        _metrics: &PerformanceMetrics,
452    ) -> Result<(), Box<dyn std::error::Error>> {
453        // Implement UCB or Thompson sampling for parameter selection
454        // This is a simplified version
455
456        use scirs2_core::random::prelude::*;
457        let mut rng = thread_rng();
458
459        for (param_name, param_value) in params.iter_mut() {
460            if rng.gen::<f64>() < self.config.exploration_rate {
461                // Explore: random perturbation
462                let perturbation = rng.gen_range(-0.1..0.1) * param_value.abs();
463                *param_value += perturbation;
464            } else {
465                // Exploit: move toward historical best
466                if let Some(best_state) = self.parameter_history.iter().min_by(|a, b| {
467                    let a_metrics = &self.performance_history[a.iteration];
468                    let b_metrics = &self.performance_history[b.iteration];
469                    a_metrics
470                        .best_energy
471                        .partial_cmp(&b_metrics.best_energy)
472                        .unwrap_or(std::cmp::Ordering::Equal)
473                }) {
474                    if let Some(best_value) = best_state.parameters.get(param_name) {
475                        *param_value += self.config.learning_rate * (best_value - *param_value);
476                    }
477                }
478            }
479        }
480
481        Ok(())
482    }
483
484    /// Initialize population for population-based methods
485    fn initialize_population(
486        &mut self,
487        base_params: &HashMap<String, f64>,
488    ) -> Result<(), Box<dyn std::error::Error>> {
489        use scirs2_core::random::prelude::*;
490        let mut rng = thread_rng();
491
492        for i in 0..self.config.population_size {
493            let mut params = base_params.clone();
494
495            // Add random perturbations
496            for value in params.values_mut() {
497                let perturbation = rng.gen_range(-0.2..0.2) * value.abs();
498                *value += perturbation;
499            }
500
501            self.population.push(Individual {
502                id: i,
503                parameters: params,
504                fitness: 0.0,
505                constraint_satisfaction: 0.0,
506            });
507        }
508
509        Ok(())
510    }
511
512    /// Initialize Lagrange multipliers
513    fn initialize_lagrange_multipliers(&mut self, penalty_weights: &HashMap<String, f64>) {
514        for (constraint_name, &weight) in penalty_weights {
515            self.lagrange_multipliers
516                .insert(constraint_name.clone(), weight * 0.1);
517        }
518    }
519
520    /// Apply penalties to model
521    fn apply_penalties(
522        &self,
523        model: &CompiledModel,
524        _penalty_weights: &HashMap<String, f64>,
525    ) -> Result<CompiledModel, Box<dyn std::error::Error>> {
526        // This would modify the model's QUBO matrix with penalty terms
527        // For now, return the original model
528        Ok(model.clone())
529    }
530
531    /// Evaluate constraint violations
532    fn evaluate_constraint_violations(
533        &self,
534        model: &CompiledModel,
535        _samples: &[SampleResult],
536    ) -> Result<HashMap<String, f64>, Box<dyn std::error::Error>> {
537        // Placeholder implementation
538        let mut violations = HashMap::new();
539
540        for constraint_name in model.get_constraints().keys() {
541            violations.insert(constraint_name.clone(), 0.0);
542        }
543
544        Ok(violations)
545    }
546
547    /// Check if solution is feasible
548    fn is_feasible(
549        &self,
550        _sample: &SampleResult,
551        constraint_violations: &HashMap<String, f64>,
552    ) -> Result<bool, Box<dyn std::error::Error>> {
553        let max_violation = constraint_violations
554            .values()
555            .fold(0.0f64, |a, &b| a.max(b.abs()));
556
557        Ok(max_violation < 1e-6)
558    }
559
560    /// Calculate solution diversity
561    fn calculate_diversity(&self, samples: &[SampleResult]) -> f64 {
562        if samples.len() < 2 {
563            return 0.0;
564        }
565
566        let mut total_distance = 0.0;
567        let mut count = 0;
568
569        for i in 0..samples.len() {
570            for j in i + 1..samples.len() {
571                let distance = self.hamming_distance(&samples[i], &samples[j]);
572                total_distance += distance as f64;
573                count += 1;
574            }
575        }
576
577        if count > 0 {
578            total_distance / count as f64
579        } else {
580            0.0
581        }
582    }
583
584    /// Calculate Hamming distance between solutions
585    fn hamming_distance(&self, a: &SampleResult, b: &SampleResult) -> usize {
586        a.assignments
587            .iter()
588            .filter(|(var, &val_a)| b.assignments.get(*var).copied().unwrap_or(false) != val_a)
589            .count()
590    }
591
592    /// Calculate temperature for annealing schedule
593    fn calculate_temperature(&self, iteration: usize, max_iterations: usize) -> f64 {
594        let progress = iteration as f64 / max_iterations as f64;
595        let initial_temp = 10.0f64;
596        let final_temp = 0.01f64;
597
598        initial_temp * (final_temp / initial_temp).powf(progress)
599    }
600
601    /// Evaluate individual fitness in population
602    fn evaluate_individual_fitness(
603        &self,
604        _individual: &Individual,
605        metrics: &PerformanceMetrics,
606    ) -> Result<f64, Box<dyn std::error::Error>> {
607        // Combine objective value and constraint satisfaction
608        let objective_score = 1.0 / (1.0 + metrics.best_energy.abs());
609        let constraint_score = metrics.feasibility_rate;
610
611        Ok(0.7f64.mul_add(objective_score, 0.3 * constraint_score))
612    }
613
614    /// Perturb individual in population
615    fn perturb_individual(
616        &self,
617        individual: &mut Individual,
618    ) -> Result<(), Box<dyn std::error::Error>> {
619        use scirs2_core::random::prelude::*;
620        let mut rng = thread_rng();
621
622        for value in individual.parameters.values_mut() {
623            if rng.gen::<f64>() < 0.3 {
624                let perturbation = rng.gen_range(-0.3..0.3) * value.abs();
625                *value += perturbation;
626            }
627        }
628
629        Ok(())
630    }
631
632    /// Export optimization history
633    pub fn export_history(&self, path: &str) -> Result<(), Box<dyn std::error::Error>> {
634        let export = AdaptiveExport {
635            config: self.config.clone(),
636            parameter_history: self.parameter_history.clone(),
637            performance_history: self.performance_history.clone(),
638            timestamp: std::time::SystemTime::now(),
639        };
640
641        let json = serde_json::to_string_pretty(&export)?;
642        std::fs::write(path, json)?;
643
644        Ok(())
645    }
646}
647
648/// Export format for adaptive optimization
649#[derive(Debug, Clone, Serialize, Deserialize)]
650pub struct AdaptiveExport {
651    pub config: AdaptiveConfig,
652    pub parameter_history: Vec<ParameterState>,
653    pub performance_history: Vec<PerformanceMetrics>,
654    pub timestamp: std::time::SystemTime,
655}
656
/// Extension hook for pushing the optimizer's parameter map into a sampler.
trait SamplerExt {
    /// Configures the sampler from a name -> value parameter map.
    fn set_parameters(&mut self, params: HashMap<String, f64>);
}

662impl<S: Sampler> SamplerExt for S {
663    fn set_parameters(&mut self, _params: HashMap<String, f64>) {
664        // This would be implemented by specific samplers
665        // For now, it's a no-op
666    }
667}