1use scirs2_core::ndarray::{Array, Dimension, Ix2};
4use scirs2_core::random::prelude::*;
5use scirs2_core::random::rngs::StdRng;
6use std::collections::HashMap;
7
8use super::{SampleResult, Sampler, SamplerResult};
9
/// Genetic-algorithm sampler for QUBO and HOBO models.
///
/// Only the RNG seed, the generation budget and the population size are
/// stored; the genetic operators used by a run are chosen inside the
/// `Sampler` methods.
#[derive(Debug, Clone)]
pub struct GASampler {
    /// RNG seed; `None` draws a fresh seed for every run.
    seed: Option<u64>,
    /// Generation budget used by QUBO runs.
    max_generations: usize,
    /// Number of individuals per generation.
    population_size: usize,
}
23
/// Recombination operator used when breeding two parents.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CrossoverStrategy {
    /// Each gene is taken from either parent with probability 1/2.
    Uniform,
    /// The children exchange every gene from one random cut point onwards.
    SinglePoint,
    /// The children exchange the genes between two random cut points.
    TwoPoint,
    /// Chooses an operator from parent similarity (1 minus the normalized
    /// Hamming distance): uniform above 0.8, two-point above 0.4, otherwise
    /// single-point.
    Adaptive,
}
36
/// Per-bit flip-probability schedule used when mutating offspring.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum MutationStrategy {
    /// Flip each bit independently with the given probability.
    FixedRate(f64),
    /// `(initial, final)`: the rate moves linearly from `initial` towards
    /// `final` as the run progresses through its generations.
    Annealing(f64, f64),
    /// `(min, max)`: the rate rises from `min` towards `max` as population
    /// diversity falls; the midpoint is used when diversity is unknown.
    Adaptive(f64, f64),
}
47
48impl GASampler {
49 #[must_use]
55 pub const fn new(seed: Option<u64>) -> Self {
56 Self {
57 seed,
58 max_generations: 1000,
59 population_size: 100,
60 }
61 }
62
63 #[must_use]
71 pub const fn with_params(
72 seed: Option<u64>,
73 max_generations: usize,
74 population_size: usize,
75 ) -> Self {
76 Self {
77 seed,
78 max_generations,
79 population_size,
80 }
81 }
82
83 pub const fn with_population_size(mut self, size: usize) -> Self {
85 self.population_size = size;
86 self
87 }
88
89 pub const fn with_elite_fraction(self, _fraction: f64) -> Self {
91 self
94 }
95
96 pub const fn with_mutation_rate(self, _rate: f64) -> Self {
98 self
101 }
102
103 pub const fn with_advanced_params(
113 seed: Option<u64>,
114 max_generations: usize,
115 population_size: usize,
116 _crossover: CrossoverStrategy, _mutation: MutationStrategy, ) -> Self {
119 Self {
120 seed,
121 max_generations,
122 population_size,
123 }
124 }
125
126 fn crossover(
128 &self,
129 parent1: &[bool],
130 parent2: &[bool],
131 strategy: CrossoverStrategy,
132 rng: &mut impl Rng,
133 ) -> (Vec<bool>, Vec<bool>) {
134 let n_vars = parent1.len();
135 let mut child1 = vec![false; n_vars];
136 let mut child2 = vec![false; n_vars];
137
138 match strategy {
139 CrossoverStrategy::Uniform => {
140 for i in 0..n_vars {
142 if rng.gen_bool(0.5) {
143 child1[i] = parent1[i];
144 child2[i] = parent2[i];
145 } else {
146 child1[i] = parent2[i];
147 child2[i] = parent1[i];
148 }
149 }
150 }
151 CrossoverStrategy::SinglePoint => {
152 let crossover_point = rng.gen_range(1..n_vars);
154
155 for i in 0..n_vars {
156 if i < crossover_point {
157 child1[i] = parent1[i];
158 child2[i] = parent2[i];
159 } else {
160 child1[i] = parent2[i];
161 child2[i] = parent1[i];
162 }
163 }
164 }
165 CrossoverStrategy::TwoPoint => {
166 let point1 = rng.gen_range(1..(n_vars - 1));
168 let point2 = rng.gen_range((point1 + 1)..n_vars);
169
170 for i in 0..n_vars {
171 if i < point1 || i >= point2 {
172 child1[i] = parent1[i];
173 child2[i] = parent2[i];
174 } else {
175 child1[i] = parent2[i];
176 child2[i] = parent1[i];
177 }
178 }
179 }
180 CrossoverStrategy::Adaptive => {
181 let mut hamming_distance = 0;
183 for i in 0..n_vars {
184 if parent1[i] != parent2[i] {
185 hamming_distance += 1;
186 }
187 }
188
189 let similarity = 1.0 - (hamming_distance as f64 / n_vars as f64);
191
192 if similarity > 0.8 {
193 for i in 0..n_vars {
195 if rng.gen_bool(0.5) {
196 child1[i] = parent1[i];
197 child2[i] = parent2[i];
198 } else {
199 child1[i] = parent2[i];
200 child2[i] = parent1[i];
201 }
202 }
203 } else if similarity > 0.4 {
204 let point1 = rng.gen_range(1..(n_vars - 1));
206 let point2 = rng.gen_range((point1 + 1)..n_vars);
207
208 for i in 0..n_vars {
209 if i < point1 || i >= point2 {
210 child1[i] = parent1[i];
211 child2[i] = parent2[i];
212 } else {
213 child1[i] = parent2[i];
214 child2[i] = parent1[i];
215 }
216 }
217 } else {
218 let crossover_point = rng.gen_range(1..n_vars);
220
221 for i in 0..n_vars {
222 if i < crossover_point {
223 child1[i] = parent1[i];
224 child2[i] = parent2[i];
225 } else {
226 child1[i] = parent2[i];
227 child2[i] = parent1[i];
228 }
229 }
230 }
231 }
232 }
233
234 (child1, child2)
235 }
236
237 fn mutate(
239 &self,
240 individual: &mut [bool],
241 strategy: MutationStrategy,
242 generation: usize,
243 max_generations: usize,
244 diversity: Option<f64>,
245 rng: &mut impl Rng,
246 ) {
247 match strategy {
248 MutationStrategy::FixedRate(rate) => {
249 for bit in individual.iter_mut() {
251 if rng.gen_bool(rate) {
252 *bit = !*bit;
253 }
254 }
255 }
256 MutationStrategy::Annealing(initial_rate, final_rate) => {
257 let progress = generation as f64 / max_generations as f64;
259 let current_rate = (final_rate - initial_rate).mul_add(progress, initial_rate);
260
261 for bit in individual.iter_mut() {
262 if rng.gen_bool(current_rate) {
263 *bit = !*bit;
264 }
265 }
266 }
267 MutationStrategy::Adaptive(min_rate, max_rate) => {
268 if let Some(diversity) = diversity {
270 let rate = (max_rate - min_rate).mul_add(1.0 - diversity, min_rate);
272
273 for bit in individual.iter_mut() {
274 if rng.gen_bool(rate) {
275 *bit = !*bit;
276 }
277 }
278 } else {
279 let rate = f64::midpoint(min_rate, max_rate);
281 for bit in individual.iter_mut() {
282 if rng.gen_bool(rate) {
283 *bit = !*bit;
284 }
285 }
286 }
287 }
288 }
289 }
290
291 fn calculate_diversity(&self, population: &[Vec<bool>]) -> f64 {
293 if population.len() <= 1 {
294 return 0.0;
295 }
296
297 let n_individuals = population.len();
298 let n_vars = population[0].len();
299 let mut sum_distances = 0;
300 let mut pair_count = 0;
301
302 for i in 0..n_individuals {
303 for j in (i + 1)..n_individuals {
304 let mut distance = 0;
305 for k in 0..n_vars {
306 if population[i][k] != population[j][k] {
307 distance += 1;
308 }
309 }
310 sum_distances += distance;
311 pair_count += 1;
312 }
313 }
314
315 if pair_count > 0 {
317 (sum_distances as f64) / (pair_count as f64 * n_vars as f64)
318 } else {
319 0.0
320 }
321 }
322}
323
324impl Sampler for GASampler {
325 fn run_hobo(
326 &self,
327 hobo: &(
328 Array<f64, scirs2_core::ndarray::IxDyn>,
329 HashMap<String, usize>,
330 ),
331 shots: usize,
332 ) -> SamplerResult<Vec<SampleResult>> {
333 let (tensor, var_map) = hobo;
335
336 let actual_shots = std::cmp::max(shots, 10);
338
339 let n_vars = var_map.len();
341
342 let idx_to_var: HashMap<usize, String> = var_map
344 .iter()
345 .map(|(var, &idx)| (idx, var.clone()))
346 .collect();
347
348 let mut rng = if let Some(seed) = self.seed {
350 StdRng::seed_from_u64(seed)
351 } else {
352 let seed: u64 = thread_rng().random();
353 StdRng::seed_from_u64(seed)
354 };
355
356 if self.population_size <= 2 || n_vars == 0 {
358 let mut assignments = HashMap::new();
360 for var in var_map.keys() {
361 assignments.insert(var.clone(), false);
362 }
363
364 return Ok(vec![SampleResult {
365 assignments,
366 energy: 0.0,
367 occurrences: 1,
368 }]);
369 }
370
371 if tensor.ndim() == 2 && tensor.shape() == [n_vars, n_vars] {
373 let matrix = tensor
375 .clone()
376 .into_dimensionality::<scirs2_core::ndarray::Ix2>()
377 .map_err(|e| {
378 super::SamplerError::InvalidModel(format!(
379 "Failed to convert tensor to 2D matrix: {}",
380 e
381 ))
382 })?;
383 let qubo = (matrix, var_map.clone());
384
385 return self.run_qubo(&qubo, shots);
386 }
387
388 let evaluate_energy = |state: &[bool]| -> f64 {
391 let mut energy = 0.0;
392
393 if tensor.ndim() == 2 {
395 for i in 0..n_vars {
397 if state[i] {
398 energy += tensor[[i, i]]; for j in 0..n_vars {
401 if state[j] && j != i {
402 energy += tensor[[i, j]];
403 }
404 }
405 }
406 }
407 } else {
408 tensor.indexed_iter().for_each(|(indices, &coeff)| {
410 if coeff == 0.0 {
411 return;
412 }
413
414 let term_active = (0..indices.ndim())
416 .map(|d| indices[d])
417 .all(|idx| idx < state.len() && state[idx]);
418
419 if term_active {
420 energy += coeff;
421 }
422 });
423 }
424
425 energy
426 };
427
428 let mut solution_counts: HashMap<Vec<bool>, (f64, usize)> = HashMap::new();
430
431 let pop_size = self.population_size.clamp(10, 100);
433
434 let mut population: Vec<Vec<bool>> = (0..pop_size)
436 .map(|_| (0..n_vars).map(|_| rng.gen_bool(0.5)).collect())
437 .collect();
438
439 let mut fitness: Vec<f64> = population
441 .iter()
442 .map(|indiv| evaluate_energy(indiv))
443 .collect();
444
445 let mut best_solution = population[0].clone();
447 let mut best_fitness = fitness[0];
448
449 for (idx, fit) in fitness.iter().enumerate() {
450 if *fit < best_fitness {
451 best_fitness = *fit;
452 best_solution = population[idx].clone();
453 }
454 }
455
456 for _ in 0..30 {
458 let mut next_population = Vec::with_capacity(pop_size);
461
462 next_population.push(best_solution.clone());
464
465 while next_population.len() < pop_size {
467 let parent1_idx = tournament_selection(&fitness, 3, &mut rng);
469 let parent2_idx = tournament_selection(&fitness, 3, &mut rng);
470
471 let (mut child1, mut child2) =
473 simple_crossover(&population[parent1_idx], &population[parent2_idx], &mut rng);
474
475 mutate(&mut child1, 0.05, &mut rng);
477 mutate(&mut child2, 0.05, &mut rng);
478
479 next_population.push(child1);
481 if next_population.len() < pop_size {
482 next_population.push(child2);
483 }
484 }
485
486 population = next_population;
488 fitness = population
489 .iter()
490 .map(|indiv| evaluate_energy(indiv))
491 .collect();
492
493 for (idx, fit) in fitness.iter().enumerate() {
495 if *fit < best_fitness {
496 best_fitness = *fit;
497 best_solution = population[idx].clone();
498 }
499 }
500
501 for (idx, indiv) in population.iter().enumerate() {
503 let entry = solution_counts
504 .entry(indiv.clone())
505 .or_insert((fitness[idx], 0));
506 entry.1 += 1;
507 }
508 }
509
510 let mut results: Vec<SampleResult> = solution_counts
512 .into_iter()
513 .filter_map(|(state, (energy, count))| {
514 let assignments: HashMap<String, bool> = state
516 .iter()
517 .enumerate()
518 .filter_map(|(idx, &value)| {
519 idx_to_var
520 .get(&idx)
521 .map(|var_name| (var_name.clone(), value))
522 })
523 .collect();
524
525 if assignments.len() != state.len() {
527 return None;
528 }
529
530 Some(SampleResult {
531 assignments,
532 energy,
533 occurrences: count,
534 })
535 })
536 .collect();
537
538 results.sort_by(|a, b| {
541 a.energy
542 .partial_cmp(&b.energy)
543 .unwrap_or(std::cmp::Ordering::Equal)
544 });
545
546 if results.len() > actual_shots {
548 results.truncate(actual_shots);
549 }
550
551 Ok(results)
552 }
553
554 fn run_qubo(
555 &self,
556 qubo: &(
557 Array<f64, scirs2_core::ndarray::Ix2>,
558 HashMap<String, usize>,
559 ),
560 shots: usize,
561 ) -> SamplerResult<Vec<SampleResult>> {
562 let (matrix, var_map) = qubo;
564
565 let actual_shots = std::cmp::max(shots, 10);
567
568 let n_vars = var_map.len();
570
571 let idx_to_var: HashMap<usize, String> = var_map
573 .iter()
574 .map(|(var, &idx)| (idx, var.clone()))
575 .collect();
576
577 let mut rng = if let Some(seed) = self.seed {
579 StdRng::seed_from_u64(seed)
580 } else {
581 let seed: u64 = thread_rng().random();
582 StdRng::seed_from_u64(seed)
583 };
584
585 if self.population_size <= 2 || n_vars == 0 {
587 let mut assignments = HashMap::new();
588 for var in var_map.keys() {
589 assignments.insert(var.clone(), false);
590 }
591
592 return Ok(vec![SampleResult {
593 assignments,
594 energy: 0.0,
595 occurrences: 1,
596 }]);
597 }
598
599 let crossover_strategy = CrossoverStrategy::Adaptive;
601 let mutation_strategy = MutationStrategy::Annealing(0.1, 0.01);
602 let selection_pressure = 3; let use_elitism = true;
604
605 let mut population: Vec<Vec<bool>> = (0..self.population_size)
607 .map(|_| (0..n_vars).map(|_| rng.gen_bool(0.5)).collect())
608 .collect();
609
610 let mut fitness: Vec<f64> = population
612 .iter()
613 .map(|indiv| calculate_energy(indiv, matrix))
614 .collect();
615
616 let mut best_idx = 0;
618 let mut best_fitness = fitness[0];
619 for (idx, &fit) in fitness.iter().enumerate() {
620 if fit < best_fitness {
621 best_idx = idx;
622 best_fitness = fit;
623 }
624 }
625 let mut best_individual = population[best_idx].clone();
626 let mut best_individual_fitness = best_fitness;
627
628 let mut solution_counts: HashMap<Vec<bool>, usize> = HashMap::new();
630
631 for generation in 0..self.max_generations {
633 let diversity = self.calculate_diversity(&population);
635
636 let mut next_population = Vec::with_capacity(self.population_size);
638 let mut next_fitness = Vec::with_capacity(self.population_size);
639
640 if use_elitism {
642 next_population.push(best_individual.clone());
643 next_fitness.push(best_individual_fitness);
644 }
645
646 while next_population.len() < self.population_size {
648 let parent1_idx = tournament_selection(&fitness, selection_pressure, &mut rng);
650 let parent2_idx = tournament_selection(&fitness, selection_pressure, &mut rng);
651
652 let parent1 = &population[parent1_idx];
653 let parent2 = &population[parent2_idx];
654
655 let (mut child1, mut child2) =
657 self.crossover(parent1, parent2, crossover_strategy, &mut rng);
658
659 self.mutate(
661 &mut child1,
662 mutation_strategy,
663 generation,
664 self.max_generations,
665 Some(diversity),
666 &mut rng,
667 );
668 self.mutate(
669 &mut child2,
670 mutation_strategy,
671 generation,
672 self.max_generations,
673 Some(diversity),
674 &mut rng,
675 );
676
677 let child1_fitness = calculate_energy(&child1, matrix);
679 let child2_fitness = calculate_energy(&child2, matrix);
680
681 next_population.push(child1);
683 next_fitness.push(child1_fitness);
684
685 if next_population.len() < self.population_size {
687 next_population.push(child2);
688 next_fitness.push(child2_fitness);
689 }
690 }
691
692 population = next_population;
694 fitness = next_fitness;
695
696 best_idx = 0;
698 best_fitness = fitness[0];
699 for (idx, &fit) in fitness.iter().enumerate() {
700 if fit < best_fitness {
701 best_idx = idx;
702 best_fitness = fit;
703 }
704 }
705
706 if best_fitness < best_individual_fitness {
708 best_individual = population[best_idx].clone();
709 best_individual_fitness = best_fitness;
710 }
711
712 for individual in &population {
714 *solution_counts.entry(individual.clone()).or_insert(0) += 1;
715 }
716 }
717
718 let mut results = Vec::new();
720
721 for (solution, count) in &solution_counts {
723 if *count < 2 {
725 continue;
726 }
727
728 let energy = calculate_energy(solution, matrix);
730
731 let assignments: HashMap<String, bool> = solution
733 .iter()
734 .enumerate()
735 .filter_map(|(idx, &value)| {
736 idx_to_var
737 .get(&idx)
738 .map(|var_name| (var_name.clone(), value))
739 })
740 .collect();
741
742 if assignments.len() != solution.len() {
744 continue;
745 }
746
747 results.push(SampleResult {
749 assignments,
750 energy,
751 occurrences: *count,
752 });
753 }
754
755 results.sort_by(|a, b| {
758 a.energy
759 .partial_cmp(&b.energy)
760 .unwrap_or(std::cmp::Ordering::Equal)
761 });
762
763 if results.len() > actual_shots {
765 results.truncate(actual_shots);
766 }
767
768 Ok(results)
769 }
770}
771
772fn calculate_energy(solution: &[bool], matrix: &Array<f64, Ix2>) -> f64 {
774 let n = solution.len();
775 let mut energy = 0.0;
776
777 for i in 0..n {
779 if solution[i] {
780 energy += matrix[[i, i]];
781 }
782 }
783
784 for i in 0..n {
786 if solution[i] {
787 for j in (i + 1)..n {
788 if solution[j] {
789 energy += matrix[[i, j]];
790 }
791 }
792 }
793 }
794
795 energy
796}
797
798fn simple_crossover(
800 parent1: &[bool],
801 parent2: &[bool],
802 rng: &mut impl Rng,
803) -> (Vec<bool>, Vec<bool>) {
804 let n_vars = parent1.len();
805 let mut child1 = vec![false; n_vars];
806 let mut child2 = vec![false; n_vars];
807
808 let crossover_point = if n_vars > 1 {
810 rng.gen_range(1..n_vars)
811 } else {
812 0 };
814
815 for i in 0..n_vars {
816 if i < crossover_point {
817 child1[i] = parent1[i];
818 child2[i] = parent2[i];
819 } else {
820 child1[i] = parent2[i];
821 child2[i] = parent1[i];
822 }
823 }
824
825 (child1, child2)
826}
827
828fn mutate(individual: &mut [bool], rate: f64, rng: &mut impl Rng) {
830 for bit in individual.iter_mut() {
831 if rng.gen_bool(rate) {
832 *bit = !*bit;
833 }
834 }
835}
836
837fn tournament_selection(fitness: &[f64], tournament_size: usize, rng: &mut impl Rng) -> usize {
839 assert!(
841 !fitness.is_empty(),
842 "Cannot perform tournament selection on an empty fitness array"
843 );
844
845 if fitness.len() == 1 || tournament_size <= 1 {
846 return 0; }
848
849 let effective_tournament_size = std::cmp::min(tournament_size, fitness.len());
851
852 let mut best_idx = rng.gen_range(0..fitness.len());
853 let mut best_fitness = fitness[best_idx];
854
855 for _ in 1..(effective_tournament_size) {
856 let candidate_idx = rng.gen_range(0..fitness.len());
857 let candidate_fitness = fitness[candidate_idx];
858
859 if candidate_fitness < best_fitness {
861 best_idx = candidate_idx;
862 best_fitness = candidate_fitness;
863 }
864 }
865
866 best_idx
867}