1use super::types::*;
4use crate::sampler::{SampleResult, Sampler};
5use scirs2_core::ndarray::Array2;
6use std::collections::HashMap;
7
8pub struct HierarchicalSolver<S: Sampler> {
10 base_sampler: S,
12 strategy: HierarchicalStrategy,
14 coarsening: CoarseningStrategy,
16 min_problem_size: usize,
18 max_levels: usize,
20 refinement_iterations: usize,
22}
23
24impl<S: Sampler> HierarchicalSolver<S> {
25 pub const fn new(base_sampler: S) -> Self {
27 Self {
28 base_sampler,
29 strategy: HierarchicalStrategy::CoarsenSolve,
30 coarsening: CoarseningStrategy::VariableClustering,
31 min_problem_size: 10,
32 max_levels: 10,
33 refinement_iterations: 5,
34 }
35 }
36
37 pub const fn with_strategy(mut self, strategy: HierarchicalStrategy) -> Self {
39 self.strategy = strategy;
40 self
41 }
42
43 pub const fn with_coarsening(mut self, coarsening: CoarseningStrategy) -> Self {
45 self.coarsening = coarsening;
46 self
47 }
48
49 pub const fn with_min_problem_size(mut self, size: usize) -> Self {
51 self.min_problem_size = size;
52 self
53 }
54
55 pub fn solve(
57 &mut self,
58 qubo: &Array2<f64>,
59 var_map: &HashMap<String, usize>,
60 shots: usize,
61 ) -> Result<SampleResult, String> {
62 match self.strategy {
63 HierarchicalStrategy::CoarsenSolve => self.coarsen_solve_approach(qubo, var_map, shots),
64 HierarchicalStrategy::MultiGrid => self.multigrid_approach(qubo, var_map, shots),
65 HierarchicalStrategy::VCycle => self.v_cycle_approach(qubo, var_map, shots),
66 }
67 }
68
69 fn coarsen_solve_approach(
71 &mut self,
72 qubo: &Array2<f64>,
73 var_map: &HashMap<String, usize>,
74 shots: usize,
75 ) -> Result<SampleResult, String> {
76 let hierarchy = self.build_hierarchy(qubo, var_map)?;
78
79 let coarsest_level = hierarchy.levels.last().ok_or("Empty hierarchy")?;
81
82 let coarse_results = self
83 .base_sampler
84 .run_qubo(
85 &(coarsest_level.qubo.clone(), coarsest_level.var_map.clone()),
86 shots,
87 )
88 .map_err(|e| format!("Sampler error: {e:?}"))?;
89
90 let coarse_result = coarse_results
92 .into_iter()
93 .next()
94 .ok_or_else(|| "No solutions found".to_string())?;
95
96 self.refine_through_hierarchy(&hierarchy, coarse_result)
98 }
99
100 fn multigrid_approach(
102 &mut self,
103 qubo: &Array2<f64>,
104 var_map: &HashMap<String, usize>,
105 shots: usize,
106 ) -> Result<SampleResult, String> {
107 let initial_results = self
108 .base_sampler
109 .run_qubo(&(qubo.clone(), var_map.clone()), shots / 4)
110 .map_err(|e| format!("Initial sampler error: {e:?}"))?;
111
112 let mut current_solution = initial_results
114 .into_iter()
115 .next()
116 .ok_or_else(|| "No initial solutions found".to_string())?;
117
118 for _cycle in 0..3 {
120 current_solution =
121 self.v_cycle_refinement(qubo, var_map, ¤t_solution, shots / 4)?;
122 }
123
124 Ok(current_solution)
125 }
126
127 fn v_cycle_approach(
129 &mut self,
130 qubo: &Array2<f64>,
131 var_map: &HashMap<String, usize>,
132 shots: usize,
133 ) -> Result<SampleResult, String> {
134 self.v_cycle_refinement(qubo, var_map, &SampleResult::default(), shots)
135 }
136
137 fn v_cycle_refinement(
139 &mut self,
140 qubo: &Array2<f64>,
141 var_map: &HashMap<String, usize>,
142 initial_solution: &SampleResult,
143 shots: usize,
144 ) -> Result<SampleResult, String> {
145 let hierarchy = self.build_hierarchy(qubo, var_map)?;
147
148 let mut current_solution = initial_solution.clone();
150
151 for level in 1..hierarchy.levels.len() {
153 current_solution = self.restrict_solution(&hierarchy, level - 1, ¤t_solution)?;
154 }
155
156 if let Some(coarsest_level) = hierarchy.levels.last() {
158 let coarse_results = self
159 .base_sampler
160 .run_qubo(
161 &(coarsest_level.qubo.clone(), coarsest_level.var_map.clone()),
162 shots,
163 )
164 .map_err(|e| format!("Coarse sampler error: {e:?}"))?;
165
166 current_solution = coarse_results
168 .into_iter()
169 .next()
170 .ok_or_else(|| "No coarse solutions found".to_string())?;
171 }
172
173 for level in (0..hierarchy.levels.len() - 1).rev() {
175 current_solution =
176 self.interpolate_and_refine(&hierarchy, level, ¤t_solution, shots / 4)?;
177 }
178
179 Ok(current_solution)
180 }
181
182 fn build_hierarchy(
184 &self,
185 qubo: &Array2<f64>,
186 var_map: &HashMap<String, usize>,
187 ) -> Result<Hierarchy, String> {
188 let mut levels = Vec::new();
189 let mut projections = Vec::new();
190
191 let mut current_qubo = qubo.clone();
192 let mut current_var_map = var_map.clone();
193 let mut current_size = current_qubo.shape()[0];
194 let mut level = 0;
195
196 levels.push(HierarchyLevel {
198 level,
199 qubo: current_qubo.clone(),
200 var_map: current_var_map.clone(),
201 size: current_size,
202 });
203
204 while current_size > self.min_problem_size && level < self.max_levels {
206 let (coarse_qubo, coarse_var_map, projection) =
207 self.coarsen_problem(¤t_qubo, ¤t_var_map)?;
208
209 current_qubo = coarse_qubo;
210 current_var_map = coarse_var_map;
211 current_size = current_qubo.shape()[0];
212 level += 1;
213
214 levels.push(HierarchyLevel {
215 level,
216 qubo: current_qubo.clone(),
217 var_map: current_var_map.clone(),
218 size: current_size,
219 });
220
221 projections.push(projection);
222 }
223
224 Ok(Hierarchy {
225 levels,
226 projections,
227 })
228 }
229
230 fn coarsen_problem(
232 &self,
233 qubo: &Array2<f64>,
234 var_map: &HashMap<String, usize>,
235 ) -> Result<(Array2<f64>, HashMap<String, usize>, Projection), String> {
236 match self.coarsening {
237 CoarseningStrategy::VariableClustering => {
238 self.variable_clustering_coarsen(qubo, var_map)
239 }
240 _ => {
241 self.variable_clustering_coarsen(qubo, var_map)
243 }
244 }
245 }
246
247 fn variable_clustering_coarsen(
249 &self,
250 qubo: &Array2<f64>,
251 _var_map: &HashMap<String, usize>,
252 ) -> Result<(Array2<f64>, HashMap<String, usize>, Projection), String> {
253 let n = qubo.shape()[0];
254
255 let mut clusters = Vec::new();
257 let mut assigned = vec![false; n];
258
259 for i in 0..n {
260 if !assigned[i] {
261 let mut cluster = vec![i];
262 assigned[i] = true;
263
264 for j in i + 1..n {
266 if !assigned[j] && qubo[[i, j]].abs() > 0.5 {
267 cluster.push(j);
268 assigned[j] = true;
269 }
270 }
271
272 clusters.push(cluster);
273 }
274 }
275
276 let num_clusters = clusters.len();
278 let mut coarse_qubo = Array2::zeros((num_clusters, num_clusters));
279
280 for (ci, cluster_i) in clusters.iter().enumerate() {
281 for (cj, cluster_j) in clusters.iter().enumerate() {
282 let mut weight = 0.0;
283
284 for &i in cluster_i {
285 for &j in cluster_j {
286 weight += qubo[[i, j]];
287 }
288 }
289
290 coarse_qubo[[ci, cj]] = weight;
291 }
292 }
293
294 let mut coarse_var_map = HashMap::new();
296 for (ci, _cluster) in clusters.iter().enumerate() {
297 let var_name = format!("cluster_{ci}");
298 coarse_var_map.insert(var_name, ci);
299 }
300
301 let projection = Projection {
303 fine_to_coarse: (0..n)
304 .map(|i| clusters.iter().position(|c| c.contains(&i)).unwrap_or(0))
305 .collect(),
306 coarse_to_fine: clusters,
307 };
308
309 Ok((coarse_qubo, coarse_var_map, projection))
310 }
311
312 fn refine_through_hierarchy(
314 &mut self,
315 hierarchy: &Hierarchy,
316 coarse_solution: SampleResult,
317 ) -> Result<SampleResult, String> {
318 let mut current_solution = coarse_solution;
319
320 for level in (0..hierarchy.levels.len() - 1).rev() {
322 current_solution = self.interpolate_solution(hierarchy, level, ¤t_solution)?;
323
324 for _iter in 0..self.refinement_iterations {
326 current_solution = self.refine_solution(
327 &hierarchy.levels[level].qubo,
328 &hierarchy.levels[level].var_map,
329 ¤t_solution,
330 10, )?;
332 }
333 }
334
335 Ok(current_solution)
336 }
337
338 fn restrict_solution(
340 &self,
341 hierarchy: &Hierarchy,
342 level: usize,
343 solution: &SampleResult,
344 ) -> Result<SampleResult, String> {
345 if level >= hierarchy.projections.len() {
346 return Ok(solution.clone());
347 }
348
349 let projection = &hierarchy.projections[level];
350 let coarse_level = &hierarchy.levels[level + 1];
351
352 let restricted_solution = SampleResult::default();
354
355 for (var_name, &coarse_idx) in &coarse_level.var_map {
356 let fine_vars = &projection.coarse_to_fine[coarse_idx];
358
359 let mut votes = 0i32;
361 for &fine_idx in fine_vars {
362 if let Some(fine_var_name) = hierarchy.levels[level]
363 .var_map
364 .iter()
365 .find(|(_, &idx)| idx == fine_idx)
366 .map(|(name, _)| name)
367 {
368 if let Some(sample) = solution.best_sample() {
369 if let Some(&value) = sample.get(fine_var_name) {
370 votes += if value { 1 } else { -1 };
371 }
372 }
373 }
374 }
375
376 if let Some(mut best_sample) = restricted_solution.best_sample().cloned() {
378 best_sample.insert(var_name.clone(), votes > 0);
379 } else {
380 let mut new_sample = HashMap::new();
381 new_sample.insert(var_name.clone(), votes > 0);
382 }
385 }
386
387 Ok(restricted_solution)
388 }
389
390 fn interpolate_solution(
392 &self,
393 hierarchy: &Hierarchy,
394 level: usize,
395 coarse_solution: &SampleResult,
396 ) -> Result<SampleResult, String> {
397 if level >= hierarchy.projections.len() {
398 return Ok(coarse_solution.clone());
399 }
400
401 let projection = &hierarchy.projections[level];
402 let fine_level = &hierarchy.levels[level];
403 let coarse_level = &hierarchy.levels[level + 1];
404
405 let interpolated_solution = SampleResult::default();
406
407 for (fine_var_name, &fine_idx) in &fine_level.var_map {
409 let coarse_idx = projection.fine_to_coarse[fine_idx];
410
411 if let Some((coarse_var_name, _)) = coarse_level
413 .var_map
414 .iter()
415 .find(|(_, &idx)| idx == coarse_idx)
416 {
417 if let Some(coarse_sample) = coarse_solution.best_sample() {
418 if let Some(&coarse_value) = coarse_sample.get(coarse_var_name) {
419 if let Some(mut fine_sample) = interpolated_solution.best_sample().cloned()
421 {
422 fine_sample.insert(fine_var_name.clone(), coarse_value);
423 } else {
424 let mut new_sample = HashMap::new();
425 new_sample.insert(fine_var_name.clone(), coarse_value);
426 }
428 }
429 }
430 }
431 }
432
433 Ok(interpolated_solution)
434 }
435
436 fn interpolate_and_refine(
438 &mut self,
439 hierarchy: &Hierarchy,
440 level: usize,
441 coarse_solution: &SampleResult,
442 shots: usize,
443 ) -> Result<SampleResult, String> {
444 let mut interpolated = self.interpolate_solution(hierarchy, level, coarse_solution)?;
446
447 for _iter in 0..self.refinement_iterations {
449 interpolated = self.refine_solution(
450 &hierarchy.levels[level].qubo,
451 &hierarchy.levels[level].var_map,
452 &interpolated,
453 shots,
454 )?;
455 }
456
457 Ok(interpolated)
458 }
459
460 fn refine_solution(
462 &mut self,
463 qubo: &Array2<f64>,
464 var_map: &HashMap<String, usize>,
465 _current_solution: &SampleResult,
466 shots: usize,
467 ) -> Result<SampleResult, String> {
468 let results = self
471 .base_sampler
472 .run_qubo(&(qubo.clone(), var_map.clone()), shots)
473 .map_err(|e| format!("Refinement sampler error: {e:?}"))?;
474
475 results
477 .into_iter()
478 .next()
479 .ok_or_else(|| "No refinement results found".to_string())
480 }
481}
482
483impl Default for SampleResult {
484 fn default() -> Self {
485 Self {
488 assignments: HashMap::new(),
489 energy: 0.0,
490 occurrences: 0,
491 }
492 }
493}
494
495impl SampleResult {
496 pub const fn best_sample(&self) -> Option<&HashMap<String, bool>> {
498 Some(&self.assignments)
499 }
500}
501
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sampler::simulated_annealing::SASampler;
    use scirs2_core::ndarray::Array2;
    use std::collections::HashMap;

    /// Builds the variable map `x0 -> 0, x1 -> 1, ...` for `n` variables.
    fn index_var_map(n: usize) -> HashMap<String, usize> {
        (0..n).map(|i| (format!("x{i}"), i)).collect()
    }

    /// 4x4 QUBO in which (0, 1) and (2, 3) are coupled above the clustering threshold.
    fn strongly_coupled_qubo() -> Array2<f64> {
        Array2::from_shape_vec(
            (4, 4),
            vec![
                1.0, 0.9, 0.0, 0.0, 0.9, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.9, 0.0, 0.0, 0.9, 1.0,
            ],
        )
        .expect("QUBO matrix construction should succeed")
    }

    #[test]
    fn test_hierarchical_solver_creation() {
        let base_sampler = SASampler::new(None);
        let solver = HierarchicalSolver::new(base_sampler);

        assert_eq!(solver.min_problem_size, 10);
        assert_eq!(solver.max_levels, 10);
    }

    #[test]
    fn test_hierarchy_building() {
        let base_sampler = SASampler::new(None);
        let solver = HierarchicalSolver::new(base_sampler);

        let qubo = Array2::from_shape_vec(
            (4, 4),
            vec![
                1.0, 0.5, 0.1, 0.0, 0.5, 1.0, 0.0, 0.1, 0.1, 0.0, 1.0, 0.5, 0.0, 0.1, 0.5, 1.0,
            ],
        )
        .expect("QUBO matrix construction should succeed");
        let var_map = index_var_map(4);

        let hierarchy = solver.build_hierarchy(&qubo, &var_map);
        assert!(hierarchy.is_ok());

        let h = hierarchy.expect("Hierarchy building should succeed");
        assert!(!h.levels.is_empty());
        assert_eq!(h.levels[0].size, 4);
        // Four variables are below the default minimum size, so nothing is coarsened.
        assert_eq!(h.levels.len(), 1);
        assert!(h.projections.is_empty());
    }

    #[test]
    fn test_variable_clustering_merges_strong_couplings() {
        let solver = HierarchicalSolver::new(SASampler::new(None));
        let qubo = strongly_coupled_qubo();
        let var_map = index_var_map(4);

        let (coarse_qubo, coarse_var_map, projection) = solver
            .variable_clustering_coarsen(&qubo, &var_map)
            .expect("Coarsening should succeed");

        assert_eq!(coarse_qubo.shape(), &[2, 2]);
        // Each cluster sums its two diagonal entries and both couplings.
        assert!((coarse_qubo[[0, 0]] - 3.8).abs() < 1e-9);
        assert!((coarse_qubo[[1, 1]] - 3.8).abs() < 1e-9);
        assert!(coarse_qubo[[0, 1]].abs() < 1e-9);

        assert_eq!(coarse_var_map.get("cluster_0"), Some(&0));
        assert_eq!(coarse_var_map.get("cluster_1"), Some(&1));

        assert_eq!(projection.coarse_to_fine.len(), 2);
        assert_eq!(projection.fine_to_coarse[0], 0);
        assert_eq!(projection.fine_to_coarse[1], 0);
        assert_eq!(projection.fine_to_coarse[2], 1);
        assert_eq!(projection.fine_to_coarse[3], 1);
    }

    #[test]
    fn test_hierarchy_coarsens_large_enough_problems() {
        let solver = HierarchicalSolver::new(SASampler::new(None)).with_min_problem_size(1);
        let qubo = strongly_coupled_qubo();
        let var_map = index_var_map(4);

        let h = solver
            .build_hierarchy(&qubo, &var_map)
            .expect("Hierarchy building should succeed");

        assert!(h.levels.len() >= 2);
        assert_eq!(h.levels[0].size, 4);
        assert_eq!(h.levels[1].size, 2);
        // One projection links each pair of adjacent levels.
        assert_eq!(h.projections.len(), h.levels.len() - 1);
    }
}