1use crate::cube::*;
6use crate::pruning::PruningTables;
7use std::collections::HashMap;
8use std::collections::HashSet;
9
10use priority_queue::PriorityQueue;
11
12pub trait Solver {
16 fn get_start_state(&self) -> &CubeState;
18
19 fn solve(&self) -> MoveSequence;
22}
23
24pub struct AStarSolver {
32 start_state: CubeState,
33}
34
35impl AStarSolver {
36 pub fn new(state: CubeState) -> Self {
37 AStarSolver { start_state: state }
38 }
39}
40
41impl Solver for AStarSolver {
42 fn get_start_state(&self) -> &CubeState {
43 &self.start_state
44 }
45
46 fn solve(&self) -> MoveSequence {
47 let mut queue = PriorityQueue::new();
48 let mut visited = HashSet::<CubeState>::new();
49 let mut come_from = HashMap::<CubeState, (CubeState, MoveInstance)>::new();
50 let mut g_scores = HashMap::<CubeState, i32>::new();
51
52 queue.push(self.get_start_state().clone(), 0);
54 g_scores.insert(self.get_start_state().clone(), 0);
55 while queue.len() > 0 {
56 if let Some((current, priority)) = queue.pop() {
57 if current == CubeState::default() {
58 break;
60 }
61 if visited.contains(¤t) {
62 continue;
63 }
64 visited.insert(current.clone());
65 for m in ALL_MOVES.iter() {
67 let new_state = current.apply_move_instance(m);
68 let new_g_score = priority - 1;
69 let neighbor_g_score = g_scores.get(&new_state).unwrap_or(&std::i32::MIN);
70 if new_g_score > *neighbor_g_score {
71 come_from.insert(new_state.clone(), (current.clone(), *m));
72 g_scores.insert(new_state.clone(), new_g_score);
73 }
74 if let None = queue.get(&new_state) {
75 queue.push(new_state, priority - 1);
76 } else if let Some((_, p)) = queue.get(&new_state) {
77 if *p < priority - 1 {
78 queue.push(new_state, priority - 1);
79 }
80 }
81 }
82 }
83 }
84 let mut curr = CubeState::default();
86 let mut path = vec![];
87 while curr != self.get_start_state().clone() {
88 if let Some((c, m)) = come_from.get(&curr) {
89 path.push(m.clone());
90 curr = c.clone();
91 }
92 }
93 path.reverse();
94 MoveSequence(path)
95 }
96}
97
98pub struct IDASolver<'a> {
106 start_state: CubeState,
107 pruning_tables: &'a PruningTables,
108}
109
/// Outcome of one bounded depth-first pass of the IDA* search.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SearchResult {
    /// The solved state was reached without exceeding the bound.
    Found,
    /// The solved state was not reached.
    ///
    /// Carries the smallest `f = g + h` that exceeded the bound (or `u8::MAX`
    /// if none did), to be used as the next bound.
    NewBound(u8),
}
114
115impl<'a> IDASolver<'a> {
116 pub fn new(state: CubeState, tables: &'a PruningTables) -> Self {
117 Self {
118 start_state: state,
119 pruning_tables: tables,
120 }
121 }
122
123 fn search_for_solution(
124 &self,
125 mut curr_path: &mut MoveSequence,
126 last_state: &CubeState,
127 g: u8,
128 bound: u8,
129 ) -> SearchResult {
130 let last_h = self.pruning_tables.compute_h_value(&last_state);
131 let f = g + last_h;
132 if f > bound {
133 SearchResult::NewBound(f)
134 } else if *last_state == CubeState::default() {
135 SearchResult::Found
137 } else {
138 let mut min = std::u8::MAX;
139 let allowed_moves = allowed_moves_after_seq(&curr_path);
140 for m in ALL_MOVES
141 .iter()
142 .filter(|mo| ((1 << get_basemove_pos(mo.basemove)) & allowed_moves) == 0)
143 {
144 if curr_path.get_moves().len() > 0 {
145 let path = curr_path.get_moves_mut();
146 let last_move = path[path.len() - 1];
147 if last_move.basemove == m.basemove {
148 continue;
149 }
150 }
151 curr_path.get_moves_mut().push(*m);
152 let next_state = last_state.apply_move_instance(m);
153 let t = self.search_for_solution(&mut curr_path, &next_state, g + 1, bound);
154 match t {
155 SearchResult::Found => return SearchResult::Found,
156 SearchResult::NewBound(b) => {
157 min = std::cmp::min(b, min);
158 }
159 };
160 curr_path.get_moves_mut().pop();
161 }
162 SearchResult::NewBound(min)
163 }
164 }
165}
166
167impl Solver for IDASolver<'_> {
168 fn get_start_state(&self) -> &CubeState {
169 &self.start_state
170 }
171
172 fn solve(&self) -> MoveSequence {
173 let start_state = self.get_start_state();
174
175 let mut bound = self.pruning_tables.compute_h_value(&start_state);
177 let mut path: MoveSequence = MoveSequence(vec![]);
178 loop {
179 println!("Searching depth {}...", bound);
180 match self.search_for_solution(&mut path, &start_state, 0, bound) {
181 SearchResult::Found => {
182 break;
183 }
184 SearchResult::NewBound(t) => {
185 bound = t;
186 }
187 }
188 }
189 path
190 }
191}