1use crate::sampler::{SampleResult, Sampler};
7use scirs2_core::ndarray::Array2;
8use scirs2_core::random::prelude::*;
9use serde::{Deserialize, Serialize};
10use std::collections::{HashMap, VecDeque};
11use std::time::{Duration, Instant};
12
13pub struct AdaptiveOptimizer {
15 samplers: Vec<(String, Box<dyn Sampler>)>,
17 performance_history: PerformanceHistory,
19 problem_analyzer: ProblemAnalyzer,
21 strategy_selector: StrategySelector,
23 #[allow(dead_code)]
25 parameter_tuner: ParameterTuner,
26 learning_rate: f64,
28 #[allow(dead_code)]
30 exploration_rate: f64,
31}
32
33#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct PerformanceHistory {
35 feature_performance: HashMap<ProblemFeatures, AlgorithmPerformance>,
37 recent_runs: VecDeque<RunRecord>,
39 best_solutions: HashMap<String, BestSolution>,
41}
42
43#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)]
44pub struct ProblemFeatures {
45 size_category: SizeCategory,
47 density_category: DensityCategory,
49 structure_type: StructureType,
51 constraint_complexity: ConstraintComplexity,
53}
54
55#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)]
56pub enum SizeCategory {
57 Tiny, Small, Medium, Large, VeryLarge, }
63
64#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)]
65pub enum DensityCategory {
66 Sparse, Medium, Dense, }
70
71#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)]
72pub enum StructureType {
73 Random,
74 Regular,
75 Hierarchical,
76 Modular,
77 Unknown,
78}
79
80#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)]
81pub enum ConstraintComplexity {
82 None,
83 Simple,
84 Moderate,
85 Complex,
86}
87
88#[derive(Debug, Clone, Serialize, Deserialize)]
89pub struct AlgorithmPerformance {
90 success_rate: f64,
92 avg_quality: f64,
94 avg_time_ms: f64,
96 n_runs: usize,
98 best_params: HashMap<String, f64>,
100}
101
102#[derive(Debug, Clone, Serialize, Deserialize)]
103pub struct RunRecord {
104 problem_id: String,
106 algorithm: String,
108 parameters: HashMap<String, f64>,
110 quality: f64,
112 time_ms: f64,
114 success: bool,
116 features: ProblemFeatures,
118}
119
120#[derive(Debug, Clone, Serialize, Deserialize)]
121pub struct BestSolution {
122 problem_id: String,
124 best_energy: f64,
126 algorithm: String,
128 time_to_solution: f64,
130 solution: HashMap<String, bool>,
132}
133
134pub struct ProblemAnalyzer {
136 extractors: Vec<Box<dyn FeatureExtractor>>,
138}
139
140trait FeatureExtractor: Send + Sync {
141 fn extract(&self, qubo: &Array2<f64>) -> HashMap<String, f64>;
142}
143
144pub struct StrategySelector {
146 strategy: SelectionStrategy,
148 #[allow(dead_code)]
150 performance_threshold: f64,
151}
152
/// How the optimizer chooses an algorithm for a new problem.
#[derive(Debug, Clone)]
pub enum SelectionStrategy {
    /// Thompson sampling over algorithms.
    ThompsonSampling,
    /// Upper-confidence-bound selection with exploration constant `c`.
    UCB {
        /// Exploration constant.
        c: f64,
    },
    /// Random choice with probability `epsilon`, otherwise the best known choice.
    EpsilonGreedy {
        /// Probability of picking a random algorithm.
        epsilon: f64,
    },
    /// Uses history when it is plentiful and reliable, explores otherwise.
    Adaptive,
}
164
165pub struct ParameterTuner {
167 #[allow(dead_code)]
169 param_ranges: HashMap<String, (f64, f64)>,
170 #[allow(dead_code)]
172 tuning_method: TuningMethod,
173 #[allow(dead_code)]
175 tuning_history: HashMap<String, Vec<(HashMap<String, f64>, f64)>>,
176}
177
178#[derive(Debug, Clone)]
179pub enum TuningMethod {
180 Grid { resolution: usize },
182 Random { n_trials: usize },
184 Bayesian,
186 Evolutionary { population_size: usize },
188}
189
190impl Default for AdaptiveOptimizer {
191 fn default() -> Self {
192 Self::new()
193 }
194}
195
196impl AdaptiveOptimizer {
197 pub fn new() -> Self {
199 Self {
200 samplers: vec![
201 (
202 "SA".to_string(),
203 Box::new(crate::sampler::SASampler::new(None)),
204 ),
205 (
206 "GA".to_string(),
207 Box::new(crate::sampler::GASampler::new(None)),
208 ),
209 ],
210 performance_history: PerformanceHistory {
211 feature_performance: HashMap::new(),
212 recent_runs: VecDeque::with_capacity(1000),
213 best_solutions: HashMap::new(),
214 },
215 problem_analyzer: ProblemAnalyzer::new(),
216 strategy_selector: StrategySelector::new(SelectionStrategy::Adaptive),
217 parameter_tuner: ParameterTuner::new(TuningMethod::Bayesian),
218 learning_rate: 0.1,
219 exploration_rate: 0.2,
220 }
221 }
222
223 pub fn add_sampler(&mut self, name: String, sampler: Box<dyn Sampler>) {
225 self.samplers.push((name, sampler));
226 }
227
228 #[must_use]
230 pub const fn with_learning_rate(mut self, rate: f64) -> Self {
231 self.learning_rate = rate;
232 self
233 }
234
235 pub fn optimize(
237 &mut self,
238 qubo: &Array2<f64>,
239 var_map: &HashMap<String, usize>,
240 time_limit: Duration,
241 ) -> Result<OptimizationResult, String> {
242 let start_time = Instant::now();
243 let problem_id = self.generate_problem_id(qubo);
244
245 let features = self.problem_analyzer.analyze(qubo);
247
248 let (algorithm, parameters) = self.select_algorithm_and_params(&features)?;
250
251 let sampler = self.configure_sampler(&algorithm, ¶meters)?;
253
254 let mut best_result: Option<SampleResult> = None;
256 let mut iterations = 0;
257 let mut improvement_history = Vec::new();
258
259 while start_time.elapsed() < time_limit {
260 iterations += 1;
261
262 let shots = self.calculate_shot_count(&features, start_time.elapsed(), time_limit);
264
265 match sampler.run_qubo(&(qubo.clone(), var_map.clone()), shots) {
267 Ok(results) => {
268 for result in results {
269 let should_update = best_result
270 .as_ref()
271 .map_or(true, |best| result.energy < best.energy);
272 if should_update {
273 improvement_history
274 .push((start_time.elapsed().as_secs_f64(), result.energy));
275 best_result = Some(result);
276 }
277 }
278 }
279 Err(e) => {
280 self.record_run(RunRecord {
282 problem_id,
283 algorithm: algorithm.clone(),
284 parameters: parameters.clone(),
285 quality: f64::INFINITY,
286 time_ms: start_time.elapsed().as_millis() as f64,
287 success: false,
288 features,
289 });
290
291 return Err(format!("Sampler error: {e:?}"));
292 }
293 }
294
295 if self.check_convergence(&improvement_history) {
297 break;
298 }
299
300 if iterations % 10 == 0 {
302 self.adjust_parameters(&mut parameters.clone(), &improvement_history);
303 }
304 }
305
306 let total_time = start_time.elapsed();
307
308 if let Some(best) = best_result {
309 self.record_run(RunRecord {
311 problem_id: problem_id.clone(),
312 algorithm: algorithm.clone(),
313 parameters,
314 quality: best.energy,
315 time_ms: total_time.as_millis() as f64,
316 success: true,
317 features: features.clone(),
318 });
319
320 self.update_best_solution(problem_id, &best, &algorithm, total_time.as_secs_f64());
322
323 Ok(OptimizationResult {
324 best_solution: best.assignments,
325 best_energy: best.energy,
326 algorithm_used: algorithm,
327 time_taken: total_time,
328 iterations,
329 improvement_history,
330 features,
331 })
332 } else {
333 Err("No solution found".to_string())
334 }
335 }
336
337 fn generate_problem_id(&self, qubo: &Array2<f64>) -> String {
339 use std::collections::hash_map::DefaultHasher;
340 use std::hash::{Hash, Hasher};
341
342 let mut hasher = DefaultHasher::new();
343 qubo.shape().hash(&mut hasher);
344
345 let n = qubo.shape()[0];
347 for i in (0..n).step_by((n / 10).max(1)) {
348 for j in (0..n).step_by((n / 10).max(1)) {
349 (qubo[[i, j]].to_bits()).hash(&mut hasher);
350 }
351 }
352
353 format!("prob_{:x}", hasher.finish())
354 }
355
356 fn select_algorithm_and_params(
358 &self,
359 features: &ProblemFeatures,
360 ) -> Result<(String, HashMap<String, f64>), String> {
361 let perf = self.performance_history.feature_performance.get(features);
363
364 match &self.strategy_selector.strategy {
365 SelectionStrategy::Adaptive => {
366 if let Some(perf) = perf {
367 if perf.n_runs > 10 && perf.success_rate > 0.8 {
368 let algorithm = self.get_best_algorithm_for_features(features);
370 let params = perf.best_params.clone();
371 Ok((algorithm, params))
372 } else {
373 self.explore_new_algorithm(features)
375 }
376 } else {
377 self.select_by_heuristics(features)
379 }
380 }
381 SelectionStrategy::ThompsonSampling => self.thompson_sampling_select(features),
382 SelectionStrategy::UCB { c } => self.ucb_select(features, *c),
383 SelectionStrategy::EpsilonGreedy { epsilon } => {
384 if thread_rng().gen::<f64>() < *epsilon {
385 self.random_select()
386 } else {
387 self.greedy_select(features)
388 }
389 }
390 }
391 }
392
393 fn select_by_heuristics(
395 &self,
396 features: &ProblemFeatures,
397 ) -> Result<(String, HashMap<String, f64>), String> {
398 let algorithm = match (&features.size_category, &features.density_category) {
399 (SizeCategory::Tiny | SizeCategory::Small, _) => "SA",
400 (_, DensityCategory::Sparse) => "GA",
401 (SizeCategory::Medium, DensityCategory::Medium) => "SA",
402 _ => "GA",
403 };
404
405 let params = self.get_default_params(algorithm);
406 Ok((algorithm.to_string(), params))
407 }
408
409 fn get_default_params(&self, algorithm: &str) -> HashMap<String, f64> {
411 let mut params = HashMap::new();
412
413 match algorithm {
414 "SA" => {
415 params.insert("beta_min".to_string(), 0.1);
416 params.insert("beta_max".to_string(), 10.0);
417 params.insert("sweeps".to_string(), 1000.0);
418 }
419 "GA" => {
420 params.insert("population_size".to_string(), 100.0);
421 params.insert("elite_fraction".to_string(), 0.1);
422 params.insert("mutation_rate".to_string(), 0.01);
423 }
424 _ => {}
425 }
426
427 params
428 }
429
430 fn configure_sampler(
432 &self,
433 algorithm: &str,
434 parameters: &HashMap<String, f64>,
435 ) -> Result<Box<dyn Sampler>, String> {
436 match algorithm {
437 "SA" => {
438 let mut sampler = crate::sampler::SASampler::new(None);
439
440 if let Some(&beta_min) = parameters.get("beta_min") {
441 if let Some(&beta_max) = parameters.get("beta_max") {
442 sampler = sampler.with_beta_range(beta_min, beta_max);
443 }
444 }
445
446 if let Some(&sweeps) = parameters.get("sweeps") {
447 sampler = sampler.with_sweeps(sweeps as usize);
448 }
449
450 Ok(Box::new(sampler))
451 }
452 "GA" => {
453 let mut sampler = crate::sampler::GASampler::new(None);
454
455 if let Some(&pop_size) = parameters.get("population_size") {
456 sampler = sampler.with_population_size(pop_size as usize);
457 }
458
459 if let Some(&elite) = parameters.get("elite_fraction") {
460 sampler = sampler.with_elite_fraction(elite);
461 }
462
463 if let Some(&mutation) = parameters.get("mutation_rate") {
464 sampler = sampler.with_mutation_rate(mutation);
465 }
466
467 Ok(Box::new(sampler))
468 }
469 _ => Err(format!("Unknown algorithm: {algorithm}")),
470 }
471 }
472
473 fn calculate_shot_count(
475 &self,
476 features: &ProblemFeatures,
477 elapsed: Duration,
478 time_limit: Duration,
479 ) -> usize {
480 let remaining_fraction = 1.0 - (elapsed.as_secs_f64() / time_limit.as_secs_f64());
481
482 let base_shots = match features.size_category {
483 SizeCategory::Tiny => 10,
484 SizeCategory::Small => 50,
485 SizeCategory::Medium => 100,
486 SizeCategory::Large => 200,
487 SizeCategory::VeryLarge => 500,
488 };
489
490 ((base_shots as f64) * remaining_fraction.sqrt()) as usize
491 }
492
493 fn check_convergence(&self, history: &[(f64, f64)]) -> bool {
495 if history.len() < 10 {
496 return false;
497 }
498
499 let recent = &history[history.len() - 10..];
501 let best_recent = recent
502 .iter()
503 .map(|(_, e)| *e)
504 .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
505 let best_overall = history
506 .iter()
507 .map(|(_, e)| *e)
508 .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
509
510 match (best_recent, best_overall) {
511 (Some(recent), Some(overall)) => (recent - overall).abs() < 1e-6,
512 _ => false,
513 }
514 }
515
516 fn adjust_parameters(&mut self, params: &mut HashMap<String, f64>, history: &[(f64, f64)]) {
518 if history.len() < 2 {
519 return;
520 }
521
522 let Some(last_entry) = history.last() else {
525 return;
526 };
527 let recent_improvement = last_entry.1 - history[history.len() - 2].1;
528
529 if recent_improvement < 0.0 {
530 for (_, value) in params.iter_mut() {
532 *value *= 1.0 + self.learning_rate;
533 }
534 } else {
535 for (_, value) in params.iter_mut() {
537 *value *= self.learning_rate.mul_add(-0.5, 1.0);
538 }
539 }
540 }
541
542 fn record_run(&mut self, record: RunRecord) {
544 self.performance_history
546 .recent_runs
547 .push_back(record.clone());
548 if self.performance_history.recent_runs.len() > 1000 {
549 self.performance_history.recent_runs.pop_front();
550 }
551
552 let perf = self
554 .performance_history
555 .feature_performance
556 .entry(record.features.clone())
557 .or_insert_with(|| AlgorithmPerformance {
558 success_rate: 0.0,
559 avg_quality: 0.0,
560 avg_time_ms: 0.0,
561 n_runs: 0,
562 best_params: HashMap::new(),
563 });
564
565 let n = perf.n_runs as f64;
567 perf.avg_quality = perf.avg_quality.mul_add(n, record.quality) / (n + 1.0);
568 perf.avg_time_ms = perf.avg_time_ms.mul_add(n, record.time_ms) / (n + 1.0);
569 perf.success_rate = perf
570 .success_rate
571 .mul_add(n, if record.success { 1.0 } else { 0.0 })
572 / (n + 1.0);
573 perf.n_runs += 1;
574
575 if record.success && (perf.best_params.is_empty() || record.quality < perf.avg_quality) {
577 perf.best_params = record.parameters;
578 }
579 }
580
581 fn update_best_solution(
583 &mut self,
584 problem_id: String,
585 result: &SampleResult,
586 algorithm: &str,
587 time: f64,
588 ) {
589 let entry = self
590 .performance_history
591 .best_solutions
592 .entry(problem_id.clone())
593 .or_insert_with(|| BestSolution {
594 problem_id,
595 best_energy: f64::INFINITY,
596 algorithm: String::new(),
597 time_to_solution: 0.0,
598 solution: HashMap::new(),
599 });
600
601 if result.energy < entry.best_energy {
602 entry.best_energy = result.energy;
603 entry.algorithm = algorithm.to_string();
604 entry.time_to_solution = time;
605 entry.solution = result.assignments.clone();
606 }
607 }
608
609 fn get_best_algorithm_for_features(&self, _features: &ProblemFeatures) -> String {
612 "SA".to_string()
614 }
615
616 fn explore_new_algorithm(
617 &self,
618 _features: &ProblemFeatures,
619 ) -> Result<(String, HashMap<String, f64>), String> {
620 let idx = thread_rng().gen_range(0..self.samplers.len());
621 let algorithm = self.samplers[idx].0.clone();
622 let params = self.get_default_params(&algorithm);
623 Ok((algorithm, params))
624 }
625
626 fn thompson_sampling_select(
627 &self,
628 _features: &ProblemFeatures,
629 ) -> Result<(String, HashMap<String, f64>), String> {
630 self.random_select()
632 }
633
634 fn ucb_select(
635 &self,
636 features: &ProblemFeatures,
637 _c: f64,
638 ) -> Result<(String, HashMap<String, f64>), String> {
639 self.select_by_heuristics(features)
641 }
642
643 fn random_select(&self) -> Result<(String, HashMap<String, f64>), String> {
644 let idx = thread_rng().gen_range(0..self.samplers.len());
645 let algorithm = self.samplers[idx].0.clone();
646 let params = self.get_default_params(&algorithm);
647 Ok((algorithm, params))
648 }
649
650 fn greedy_select(
651 &self,
652 features: &ProblemFeatures,
653 ) -> Result<(String, HashMap<String, f64>), String> {
654 if let Some(perf) = self.performance_history.feature_performance.get(features) {
655 let algorithm = self.get_best_algorithm_for_features(features);
656 Ok((algorithm, perf.best_params.clone()))
657 } else {
658 self.select_by_heuristics(features)
659 }
660 }
661
662 pub fn save_history(&self, path: &str) -> Result<(), Box<dyn std::error::Error>> {
664 let json = serde_json::to_string_pretty(&self.performance_history)?;
665 std::fs::write(path, json)?;
666 Ok(())
667 }
668
669 pub fn load_history(&mut self, path: &str) -> Result<(), Box<dyn std::error::Error>> {
671 let json = std::fs::read_to_string(path)?;
672 self.performance_history = serde_json::from_str(&json)?;
673 Ok(())
674 }
675}
676
677impl ProblemAnalyzer {
678 fn new() -> Self {
680 Self {
681 extractors: vec![
682 Box::new(BasicFeatureExtractor),
683 Box::new(StructureFeatureExtractor),
684 ],
685 }
686 }
687
688 fn analyze(&self, qubo: &Array2<f64>) -> ProblemFeatures {
690 let n = qubo.shape()[0];
691
692 let mut all_features = HashMap::new();
694 for extractor in &self.extractors {
695 all_features.extend(extractor.extract(qubo));
696 }
697
698 let size_category = match n {
700 0..=9 => SizeCategory::Tiny,
701 10..=49 => SizeCategory::Small,
702 50..=199 => SizeCategory::Medium,
703 200..=999 => SizeCategory::Large,
704 _ => SizeCategory::VeryLarge,
705 };
706
707 let density = all_features.get("density").copied().unwrap_or(0.5);
708 let density_category = match density {
709 d if d < 0.1 => DensityCategory::Sparse,
710 d if d < 0.5 => DensityCategory::Medium,
711 _ => DensityCategory::Dense,
712 };
713
714 let structure_score = all_features.get("structure_score").copied().unwrap_or(0.0);
715 let structure_type = if structure_score < 0.2 {
716 StructureType::Random
717 } else if structure_score < 0.5 {
718 StructureType::Regular
719 } else {
720 StructureType::Hierarchical
721 };
722
723 ProblemFeatures {
724 size_category,
725 density_category,
726 structure_type,
727 constraint_complexity: ConstraintComplexity::None, }
729 }
730}
731
/// Extracts size, density and value statistics of a QUBO matrix.
struct BasicFeatureExtractor;
733
734impl FeatureExtractor for BasicFeatureExtractor {
735 fn extract(&self, qubo: &Array2<f64>) -> HashMap<String, f64> {
736 let n = qubo.shape()[0];
737 let mut features = HashMap::new();
738
739 features.insert("size".to_string(), n as f64);
741
742 let non_zeros = qubo.iter().filter(|&&x| x.abs() > 1e-10).count();
744 features.insert("density".to_string(), non_zeros as f64 / (n * n) as f64);
745
746 let values: Vec<f64> = qubo.iter().copied().collect();
748 features.insert(
749 "mean".to_string(),
750 values.iter().sum::<f64>() / values.len() as f64,
751 );
752 features.insert(
753 "max".to_string(),
754 values.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b)),
755 );
756 features.insert(
757 "min".to_string(),
758 values.iter().fold(f64::INFINITY, |a, &b| a.min(b)),
759 );
760
761 features
762 }
763}
764
/// Extracts diagonal-dominance and symmetry indicators of a QUBO matrix.
struct StructureFeatureExtractor;
766
767impl FeatureExtractor for StructureFeatureExtractor {
768 fn extract(&self, qubo: &Array2<f64>) -> HashMap<String, f64> {
769 let mut features = HashMap::new();
770 let n = qubo.shape()[0];
771
772 let mut diag_sum = 0.0;
774 let mut total_sum = 0.0;
775 for i in 0..n {
776 diag_sum += qubo[[i, i]].abs();
777 for j in 0..n {
778 total_sum += qubo[[i, j]].abs();
779 }
780 }
781
782 features.insert(
783 "diagonal_dominance".to_string(),
784 if total_sum > 0.0 {
785 diag_sum / total_sum
786 } else {
787 0.0
788 },
789 );
790
791 let mut symmetry_score = 0.0;
793 for i in 0..n {
794 for j in i + 1..n {
795 let diff = (qubo[[i, j]] - qubo[[j, i]]).abs();
796 let avg = f64::midpoint(qubo[[i, j]].abs(), qubo[[j, i]].abs());
797 if avg > 1e-10 {
798 symmetry_score += 1.0 - diff / avg;
799 }
800 }
801 }
802 features.insert(
803 "symmetry".to_string(),
804 symmetry_score / ((n * (n - 1)) / 2) as f64,
805 );
806
807 let diagonal_dominance = features.get("diagonal_dominance").copied().unwrap_or(0.0);
810 let symmetry = features.get("symmetry").copied().unwrap_or(0.0);
811 let structure_score = diagonal_dominance * 0.5 + symmetry * 0.5;
812 features.insert("structure_score".to_string(), structure_score);
813
814 features
815 }
816}
817
818impl StrategySelector {
819 const fn new(strategy: SelectionStrategy) -> Self {
820 Self {
821 strategy,
822 performance_threshold: 0.8,
823 }
824 }
825}
826
827impl ParameterTuner {
828 fn new(method: TuningMethod) -> Self {
829 Self {
830 param_ranges: HashMap::new(),
831 tuning_method: method,
832 tuning_history: HashMap::new(),
833 }
834 }
835}
836
837#[derive(Debug, Clone)]
839pub struct OptimizationResult {
840 pub best_solution: HashMap<String, bool>,
842 pub best_energy: f64,
844 pub algorithm_used: String,
846 pub time_taken: Duration,
848 pub iterations: usize,
850 pub improvement_history: Vec<(f64, f64)>,
852 pub features: ProblemFeatures,
854}
855
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_adaptive_optimizer() {
        let mut optimizer = AdaptiveOptimizer::new();

        let qubo = array![[0.0, -1.0, 0.5], [-1.0, 0.0, -0.5], [0.5, -0.5, 0.0]];

        let var_map: HashMap<String, usize> = [("x", 0), ("y", 1), ("z", 2)]
            .into_iter()
            .map(|(name, index)| (name.to_string(), index))
            .collect();

        let result = optimizer
            .optimize(&qubo, &var_map, Duration::from_secs(1))
            .expect("Optimization should succeed for valid QUBO");

        assert!(!result.best_solution.is_empty());
        assert!(result.best_energy < 0.0);
        assert!(!result.improvement_history.is_empty());
    }

    #[test]
    fn test_problem_analyzer() {
        let analyzer = ProblemAnalyzer::new();

        // 20 of 400 entries are non-zero (density 0.05 < 0.1). A 10x10 identity has a
        // density of exactly 0.1, which is classified as Medium, not Sparse.
        let small_sparse = Array2::<f64>::eye(20);
        let features = analyzer.analyze(&small_sparse);
        assert_eq!(features.size_category, SizeCategory::Small);
        assert_eq!(features.density_category, DensityCategory::Sparse);

        let large_dense = Array2::<f64>::ones((500, 500));
        let features = analyzer.analyze(&large_dense);
        assert_eq!(features.size_category, SizeCategory::Large);
        assert_eq!(features.density_category, DensityCategory::Dense);
    }
}
896}