1mod node;
5#[cfg(target_arch = "wasm32")]
6mod wasm;
7
8use node::{Node, NodeId};
9
10use std::collections::BTreeMap;
11
12#[derive(Default, Debug, Clone)]
13struct SolverState {
14 nodes: Vec<Node>,
15 header: NodeId,
16 column_sizes: Vec<usize>,
17}
18
19impl SolverState {
20 fn new_node(&mut self) -> NodeId {
21 self.nodes.push(Node::default());
22 NodeId::new(self.nodes.len() - 1)
23 }
24
25 fn link_horizontal(&mut self, left_id: NodeId, right_id: NodeId) {
26 let left = self.node_mut(left_id);
27 left.right = right_id;
28
29 let right = self.node_mut(right_id);
30 right.left = left_id;
31 }
32
33 fn detach_column(&mut self, node_id: NodeId) {
34 let node = self.node(node_id);
35 let header = self.node(node.header);
36
37 let header_left_id = header.left;
38 let header_right_id = header.right;
39
40 let header_left = self.node_mut(header_left_id);
41 header_left.right = header_right_id;
42
43 let header_right = self.node_mut(header_right_id);
44 header_right.left = header_left_id;
45 }
46
47 fn attach_column(&mut self, node_id: NodeId) {
48 let node = self.node_mut(node_id);
49 let header_id = node.header;
50
51 let header = self.node_mut(header_id);
52 let header_left_id = header.left;
53 let header_right_id = header.right;
54
55 let header_left = self.node_mut(header_left_id);
56 header_left.right = header_id;
57
58 let header_right = self.node_mut(header_right_id);
59 header_right.left = header_id;
60 }
61
62 fn detach_row(&mut self, node_id: NodeId) {
63 let mut current_id = self.node_mut(node_id).right;
64
65 loop {
66 if current_id == node_id {
67 break;
68 }
69
70 let current_node = self.node_mut(current_id);
71 let current_col_idx = current_node.col;
72 let current_down_id = current_node.down;
73 let current_up_id = current_node.up;
74 let current_right_id = current_node.right;
75
76 self.node_mut(current_up_id).down = current_down_id;
77 self.node_mut(current_down_id).up = current_up_id;
78
79 self.column_sizes[current_col_idx] -= 1;
80
81 current_id = current_right_id;
82 }
83 }
84
85 fn attach_row(&mut self, node_id: NodeId) {
86 let mut current_id = self.node_mut(node_id).left;
87
88 loop {
89 if current_id == node_id {
90 break;
91 }
92
93 let current_node = self.node_mut(current_id);
94 let current_col_idx = current_node.col;
95 let current_down_id = current_node.down;
96 let current_left_id = current_node.left;
97 let current_up_id = current_node.up;
98
99 self.column_sizes[current_col_idx] += 1;
100
101 self.node_mut(current_down_id).up = current_id;
102 self.node_mut(current_up_id).down = current_id;
103
104 current_id = current_left_id;
105 }
106 }
107
108 fn node_column_size(&self, id: NodeId) -> usize {
109 self.column_sizes[self.node(id).col]
110 }
111
112 fn node(&self, id: NodeId) -> &Node {
113 &self.nodes[id.value()]
114 }
115
116 fn node_mut(&mut self, id: NodeId) -> &mut Node {
117 &mut self.nodes[id.value()]
118 }
119
120 fn header_node_mut(&mut self, id: NodeId) -> &mut Node {
121 let header_node_id = self.node_mut(id).header;
122
123 self.node_mut(header_node_id)
124 }
125}
126
127#[derive(Debug, Copy, Clone)]
128struct Step {
129 node_id: NodeId,
130 backtracking: bool,
131}
132
133#[derive(Debug, Default, Clone)]
134pub struct Solver {
135 state: SolverState,
136 step_stack: Vec<Step>,
137 partial_solution: Vec<usize>,
138}
139
140impl Solver {
141 pub fn new(rows: Vec<Vec<usize>>, partial_solution: Vec<usize>) -> Self {
143 let column_count = rows.iter().flatten().copied().max().unwrap_or_default() + 1;
144
145 let mut state = SolverState {
146 nodes: vec![],
147 header: Default::default(),
148 column_sizes: vec![0; column_count],
149 };
150
151 let mut header_row: Vec<NodeId> = vec![];
152
153 let mut above_nodes = vec![NodeId::invalid(); column_count];
154
155 let mut columns_to_cover = BTreeMap::new();
156
157 for (row_idx, row) in rows.into_iter().enumerate() {
158 let mut first = NodeId::invalid();
159 let mut prev = NodeId::invalid();
160
161 for col_idx in row {
162 let node_id = state.new_node();
163
164 state.node_mut(node_id).row = row_idx as isize;
165 state.node_mut(node_id).col = col_idx;
166
167 state.column_sizes[col_idx] += 1;
168
169 if !first.is_valid() {
170 first = node_id;
171 }
172
173 if prev.is_valid() {
174 state.link_horizontal(prev, node_id);
175 }
176
177 let above_id = above_nodes[col_idx];
178 if above_id.is_valid() {
179 let above_node = state.node_mut(above_id);
180 let above_down_id = above_node.down;
181 let above_header_id = above_node.header;
182
183 above_node.down = node_id;
184
185 let node = state.node_mut(node_id);
186 node.up = above_id;
187 node.down = above_down_id;
188 node.header = above_header_id;
189
190 state.header_node_mut(node_id).up = node_id;
191 } else {
192 let header_id = state.new_node();
193 header_row.push(header_id);
194
195 let header = state.node_mut(header_id);
196 header.row = -1;
197 header.col = col_idx;
198 header.header = header_id;
199 header.up = node_id;
200 header.down = node_id;
201
202 let node = state.node_mut(node_id);
203 node.up = header_id;
204 node.down = header_id;
205 node.header = header_id;
206 }
207
208 above_nodes[col_idx] = node_id;
209 prev = node_id;
210
211 if partial_solution.contains(&col_idx) && !columns_to_cover.contains_key(&col_idx) {
212 columns_to_cover.insert(col_idx, node_id);
213 }
214 }
215
216 if first.is_valid() && prev.is_valid() {
217 state.link_horizontal(prev, first);
218 }
219 }
220
221 header_row.sort_by(|a, b| {
222 let a_col = state.node_mut(*a).col;
223 let b_col = state.node_mut(*b).col;
224 a_col.cmp(&b_col)
225 });
226
227 let Some(first_header_id) = header_row.first().copied() else {
228 return Default::default();
229 };
230
231 let last_header_id = header_row.iter().last().copied().unwrap_or(first_header_id);
232
233 state.node_mut(first_header_id).left = last_header_id;
234 state.node_mut(last_header_id).right = first_header_id;
235
236 header_row.windows(2).for_each(|nodes| {
237 state.link_horizontal(nodes[0], nodes[1]);
238 });
239
240 let header_root_id = state.new_node();
241
242 state.node_mut(header_root_id).right = first_header_id;
243 state.node_mut(first_header_id).left = header_root_id;
244
245 state.node_mut(header_root_id).left = last_header_id;
246 state.node_mut(last_header_id).right = header_root_id;
247
248 state.header = header_root_id;
249
250 let mut solver = Self {
251 state: state.clone(),
252 partial_solution: Vec::with_capacity(header_row.len()),
253 step_stack: vec![],
254 };
255
256 for column_node_id in columns_to_cover.values() {
257 let column_first_node_id = state.header_node_mut(*column_node_id).down;
258 solver.cover(column_first_node_id);
259 }
260
261 if let Some(node_id) = solver.choose_column() {
262 solver.step_stack.push(Step {
263 node_id,
264 backtracking: false,
265 });
266 }
267
268 solver
269 }
270
271 fn choose_column(&self) -> Option<NodeId> {
272 let mut best_column_id = None;
273 let mut best_size = usize::MAX;
274
275 let mut current_node_id = self.state.node(self.state.header).right;
276
277 while current_node_id != self.state.header {
278 let current_size = self.state.node_column_size(current_node_id);
279
280 if current_size < best_size {
281 best_column_id = Some(current_node_id);
282 best_size = current_size;
283 }
284 current_node_id = self.state.node(current_node_id).right;
285 }
286
287 Some(self.state.node(best_column_id?).down)
288 }
289
290 pub fn partial_solution(&self) -> &[usize] {
291 &self.partial_solution
292 }
293
294 pub fn is_completed(&self) -> bool {
295 self.step_stack.is_empty()
296 }
297
298 fn cover(&mut self, node_id: NodeId) {
299 self.state.detach_column(node_id);
300
301 let node = self.state.node_mut(node_id);
302 let node_header_id = node.header;
303
304 let mut down_id = self.state.node_mut(node_header_id).down;
305 while down_id != node_header_id {
306 self.state.detach_row(down_id);
307
308 down_id = self.state.node_mut(down_id).down;
309 }
310 }
311
312 fn uncover(&mut self, node_id: NodeId) {
313 let node_header_id = self.state.node(node_id).header;
314 let mut up_id = self.state.node(node_header_id).up;
315
316 while up_id != node_header_id {
317 self.state.attach_row(up_id);
318 up_id = self.state.node(up_id).up;
319 }
320
321 self.state.attach_column(node_id);
322 }
323
324 pub fn step(&mut self) -> Option<Vec<usize>> {
325 let Step {
326 node_id,
327 backtracking,
328 } = self.step_stack.pop()?;
329
330 let node_header_id = self.state.node(node_id).header;
331
332 if node_id == node_header_id {
333 return None;
334 }
335
336 if backtracking {
337 self.step_backward(node_id);
338 } else {
339 self.step_forward(node_id);
340 }
341
342 let header_root_id = self.state.header;
343
344 if self.state.node_mut(header_root_id).right == header_root_id {
345 Some(self.partial_solution.clone())
346 } else {
347 None
348 }
349 }
350
351 fn step_forward(&mut self, node_id: NodeId) {
352 let node_row = self.state.node(node_id).row;
353 self.partial_solution.push(node_row as _);
354
355 let mut current_id = node_id;
356 loop {
357 self.cover(current_id);
358
359 current_id = self.state.node(current_id).right;
360 if current_id == node_id {
361 break;
362 }
363 }
364
365 self.step_stack.push(Step {
366 node_id,
367 backtracking: true,
368 });
369
370 if let Some(node_id) = self.choose_column() {
371 self.step_stack.push(Step {
372 node_id,
373 backtracking: false,
374 });
375 }
376 }
377
378 fn step_backward(&mut self, node_id: NodeId) {
379 self.partial_solution.pop();
380
381 let mut current_id = self.state.node(node_id).left;
382 loop {
383 self.uncover(current_id);
384
385 if current_id == node_id {
386 break;
387 }
388 current_id = self.state.node(current_id).left;
389 }
390
391 let node_down = self.state.node(node_id).down;
392 let node_header = self.state.node(node_id).header;
393
394 if node_down != node_header {
395 self.step_stack.push(Step {
396 node_id: node_down,
397 backtracking: false,
398 });
399 }
400 }
401}
402
403impl Iterator for Solver {
404 type Item = Vec<usize>;
405
406 fn next(&mut self) -> Option<Self::Item> {
407 while !self.is_completed() {
408 let step = self.step();
409
410 if step.is_some() {
411 return step;
412 }
413 }
414
415 None
416 }
417}
418
#[cfg(test)]
#[rustfmt::skip]
mod tests {
    use super::*;

    #[test]
    fn test_basic_solve() {
        // Pre-covering columns 0 and 2 removes every row touching them, leaving
        // only row 2 ({1, 3}) to cover the remaining columns 1 and 3.
        let solver = Solver::new(vec![
            vec![0, 1],
            vec![0, 2],
            vec![1, 3],
            vec![2, 3],
            vec![0, 1, 2],
            vec![1, 2, 3],
        ], vec![0, 2]);

        let solutions = solver.collect::<Vec<_>>();

        assert_eq!(vec![vec![2]], solutions);
    }

    #[test]
    fn test_all_solutions_without_partial_solution() {
        // {0,1}+{2,3} and {0,2}+{1,3} are the only exact covers of {0,1,2,3}.
        let solver = Solver::new(vec![
            vec![0, 1],
            vec![0, 2],
            vec![1, 3],
            vec![2, 3],
            vec![0, 1, 2],
            vec![1, 2, 3],
        ], vec![]);

        let solutions = solver.collect::<Vec<_>>();

        assert_eq!(vec![vec![0, 3], vec![1, 2]], solutions);
    }

    #[test]
    fn test_duplicate_rows_each_yield_a_solution() {
        let solutions = Solver::new(vec![vec![0], vec![0]], vec![]).collect::<Vec<_>>();

        assert_eq!(vec![vec![0], vec![1]], solutions);
    }

    #[test]
    fn test_unsatisfiable_yields_no_solutions() {
        // Column 1 can only be covered by rows that overlap each other on it.
        let solutions = Solver::new(vec![vec![0, 1], vec![1, 2]], vec![]).collect::<Vec<_>>();

        assert!(solutions.is_empty());
    }

    #[test]
    fn test_empty_matrix_is_completed() {
        let mut solver = Solver::new(vec![], vec![]);

        assert!(solver.is_completed());
        assert_eq!(None, solver.next());
    }

    #[test]
    fn test_fully_pre_covered_matrix_yields_no_solutions() {
        let mut solver = Solver::new(vec![vec![0], vec![1]], vec![0, 1]);

        assert!(solver.is_completed());
        assert_eq!(None, solver.next());
    }
}