algx/
lib.rs

//! Implementation of [Knuth's Algorithm X](https://en.wikipedia.org/wiki/Knuth%27s_Algorithm_X)
//! for solving the [exact cover](https://en.wikipedia.org/wiki/Exact_cover) problem, using the
//! "dancing links" representation (circular doubly linked lists of matrix cells).
//!
4mod node;
5#[cfg(target_arch = "wasm32")]
6mod wasm;
7
8use node::{Node, NodeId};
9
10use std::collections::BTreeMap;
11
12#[derive(Default, Debug, Clone)]
13struct SolverState {
14    nodes: Vec<Node>,
15    header: NodeId,
16    column_sizes: Vec<usize>,
17}
18
19impl SolverState {
20    fn new_node(&mut self) -> NodeId {
21        self.nodes.push(Node::default());
22        NodeId::new(self.nodes.len() - 1)
23    }
24
25    fn link_horizontal(&mut self, left_id: NodeId, right_id: NodeId) {
26        let left = self.node_mut(left_id);
27        left.right = right_id;
28
29        let right = self.node_mut(right_id);
30        right.left = left_id;
31    }
32
33    fn detach_column(&mut self, node_id: NodeId) {
34        let node = self.node(node_id);
35        let header = self.node(node.header);
36
37        let header_left_id = header.left;
38        let header_right_id = header.right;
39
40        let header_left = self.node_mut(header_left_id);
41        header_left.right = header_right_id;
42
43        let header_right = self.node_mut(header_right_id);
44        header_right.left = header_left_id;
45    }
46
47    fn attach_column(&mut self, node_id: NodeId) {
48        let node = self.node_mut(node_id);
49        let header_id = node.header;
50
51        let header = self.node_mut(header_id);
52        let header_left_id = header.left;
53        let header_right_id = header.right;
54
55        let header_left = self.node_mut(header_left_id);
56        header_left.right = header_id;
57
58        let header_right = self.node_mut(header_right_id);
59        header_right.left = header_id;
60    }
61
62    fn detach_row(&mut self, node_id: NodeId) {
63        let mut current_id = self.node_mut(node_id).right;
64
65        loop {
66            if current_id == node_id {
67                break;
68            }
69
70            let current_node = self.node_mut(current_id);
71            let current_col_idx = current_node.col;
72            let current_down_id = current_node.down;
73            let current_up_id = current_node.up;
74            let current_right_id = current_node.right;
75
76            self.node_mut(current_up_id).down = current_down_id;
77            self.node_mut(current_down_id).up = current_up_id;
78
79            self.column_sizes[current_col_idx] -= 1;
80
81            current_id = current_right_id;
82        }
83    }
84
85    fn attach_row(&mut self, node_id: NodeId) {
86        let mut current_id = self.node_mut(node_id).left;
87
88        loop {
89            if current_id == node_id {
90                break;
91            }
92
93            let current_node = self.node_mut(current_id);
94            let current_col_idx = current_node.col;
95            let current_down_id = current_node.down;
96            let current_left_id = current_node.left;
97            let current_up_id = current_node.up;
98
99            self.column_sizes[current_col_idx] += 1;
100
101            self.node_mut(current_down_id).up = current_id;
102            self.node_mut(current_up_id).down = current_id;
103
104            current_id = current_left_id;
105        }
106    }
107
108    fn node_column_size(&self, id: NodeId) -> usize {
109        self.column_sizes[self.node(id).col]
110    }
111
112    fn node(&self, id: NodeId) -> &Node {
113        &self.nodes[id.value()]
114    }
115
116    fn node_mut(&mut self, id: NodeId) -> &mut Node {
117        &mut self.nodes[id.value()]
118    }
119
120    fn header_node_mut(&mut self, id: NodeId) -> &mut Node {
121        let header_node_id = self.node_mut(id).header;
122
123        self.node_mut(header_node_id)
124    }
125}
126
127#[derive(Debug, Copy, Clone)]
128struct Step {
129    node_id: NodeId,
130    backtracking: bool,
131}
132
133#[derive(Debug, Default, Clone)]
134pub struct Solver {
135    state: SolverState,
136    step_stack: Vec<Step>,
137    partial_solution: Vec<usize>,
138}
139
140impl Solver {
141    /// Creates a new solver for given rows. Columns in the rows are assumed to be in ascending order
142    pub fn new(rows: Vec<Vec<usize>>, partial_solution: Vec<usize>) -> Self {
143        let column_count = rows.iter().flatten().copied().max().unwrap_or_default() + 1;
144
145        let mut state = SolverState {
146            nodes: vec![],
147            header: Default::default(),
148            column_sizes: vec![0; column_count],
149        };
150
151        let mut header_row: Vec<NodeId> = vec![];
152
153        let mut above_nodes = vec![NodeId::invalid(); column_count];
154
155        let mut columns_to_cover = BTreeMap::new();
156
157        for (row_idx, row) in rows.into_iter().enumerate() {
158            let mut first = NodeId::invalid();
159            let mut prev = NodeId::invalid();
160
161            for col_idx in row {
162                let node_id = state.new_node();
163
164                state.node_mut(node_id).row = row_idx as isize;
165                state.node_mut(node_id).col = col_idx;
166
167                state.column_sizes[col_idx] += 1;
168
169                if !first.is_valid() {
170                    first = node_id;
171                }
172
173                if prev.is_valid() {
174                    state.link_horizontal(prev, node_id);
175                }
176
177                let above_id = above_nodes[col_idx];
178                if above_id.is_valid() {
179                    let above_node = state.node_mut(above_id);
180                    let above_down_id = above_node.down;
181                    let above_header_id = above_node.header;
182
183                    above_node.down = node_id;
184
185                    let node = state.node_mut(node_id);
186                    node.up = above_id;
187                    node.down = above_down_id;
188                    node.header = above_header_id;
189
190                    state.header_node_mut(node_id).up = node_id;
191                } else {
192                    let header_id = state.new_node();
193                    header_row.push(header_id);
194
195                    let header = state.node_mut(header_id);
196                    header.row = -1;
197                    header.col = col_idx;
198                    header.header = header_id;
199                    header.up = node_id;
200                    header.down = node_id;
201
202                    let node = state.node_mut(node_id);
203                    node.up = header_id;
204                    node.down = header_id;
205                    node.header = header_id;
206                }
207
208                above_nodes[col_idx] = node_id;
209                prev = node_id;
210
211                if partial_solution.contains(&col_idx) && !columns_to_cover.contains_key(&col_idx) {
212                    columns_to_cover.insert(col_idx, node_id);
213                }
214            }
215
216            if first.is_valid() && prev.is_valid() {
217                state.link_horizontal(prev, first);
218            }
219        }
220
221        header_row.sort_by(|a, b| {
222            let a_col = state.node_mut(*a).col;
223            let b_col = state.node_mut(*b).col;
224            a_col.cmp(&b_col)
225        });
226
227        let Some(first_header_id) = header_row.first().copied() else {
228            return Default::default();
229        };
230
231        let last_header_id = header_row.iter().last().copied().unwrap_or(first_header_id);
232
233        state.node_mut(first_header_id).left = last_header_id;
234        state.node_mut(last_header_id).right = first_header_id;
235
236        header_row.windows(2).for_each(|nodes| {
237            state.link_horizontal(nodes[0], nodes[1]);
238        });
239
240        let header_root_id = state.new_node();
241
242        state.node_mut(header_root_id).right = first_header_id;
243        state.node_mut(first_header_id).left = header_root_id;
244
245        state.node_mut(header_root_id).left = last_header_id;
246        state.node_mut(last_header_id).right = header_root_id;
247
248        state.header = header_root_id;
249
250        let mut solver = Self {
251            state: state.clone(),
252            partial_solution: Vec::with_capacity(header_row.len()),
253            step_stack: vec![],
254        };
255
256        for column_node_id in columns_to_cover.values() {
257            let column_first_node_id = state.header_node_mut(*column_node_id).down;
258            solver.cover(column_first_node_id);
259        }
260
261        if let Some(node_id) = solver.choose_column() {
262            solver.step_stack.push(Step {
263                node_id,
264                backtracking: false,
265            });
266        }
267
268        solver
269    }
270
271    fn choose_column(&self) -> Option<NodeId> {
272        let mut best_column_id = None;
273        let mut best_size = usize::MAX;
274
275        let mut current_node_id = self.state.node(self.state.header).right;
276
277        while current_node_id != self.state.header {
278            let current_size = self.state.node_column_size(current_node_id);
279
280            if current_size < best_size {
281                best_column_id = Some(current_node_id);
282                best_size = current_size;
283            }
284            current_node_id = self.state.node(current_node_id).right;
285        }
286
287        Some(self.state.node(best_column_id?).down)
288    }
289
290    pub fn partial_solution(&self) -> &[usize] {
291        &self.partial_solution
292    }
293
294    pub fn is_completed(&self) -> bool {
295        self.step_stack.is_empty()
296    }
297
298    fn cover(&mut self, node_id: NodeId) {
299        self.state.detach_column(node_id);
300
301        let node = self.state.node_mut(node_id);
302        let node_header_id = node.header;
303
304        let mut down_id = self.state.node_mut(node_header_id).down;
305        while down_id != node_header_id {
306            self.state.detach_row(down_id);
307
308            down_id = self.state.node_mut(down_id).down;
309        }
310    }
311
312    fn uncover(&mut self, node_id: NodeId) {
313        let node_header_id = self.state.node(node_id).header;
314        let mut up_id = self.state.node(node_header_id).up;
315
316        while up_id != node_header_id {
317            self.state.attach_row(up_id);
318            up_id = self.state.node(up_id).up;
319        }
320
321        self.state.attach_column(node_id);
322    }
323
324    pub fn step(&mut self) -> Option<Vec<usize>> {
325        let Step {
326            node_id,
327            backtracking,
328        } = self.step_stack.pop()?;
329
330        let node_header_id = self.state.node(node_id).header;
331
332        if node_id == node_header_id {
333            return None;
334        }
335
336        if backtracking {
337            self.step_backward(node_id);
338        } else {
339            self.step_forward(node_id);
340        }
341
342        let header_root_id = self.state.header;
343
344        if self.state.node_mut(header_root_id).right == header_root_id {
345            Some(self.partial_solution.clone())
346        } else {
347            None
348        }
349    }
350
351    fn step_forward(&mut self, node_id: NodeId) {
352        let node_row = self.state.node(node_id).row;
353        self.partial_solution.push(node_row as _);
354
355        let mut current_id = node_id;
356        loop {
357            self.cover(current_id);
358
359            current_id = self.state.node(current_id).right;
360            if current_id == node_id {
361                break;
362            }
363        }
364
365        self.step_stack.push(Step {
366            node_id,
367            backtracking: true,
368        });
369
370        if let Some(node_id) = self.choose_column() {
371            self.step_stack.push(Step {
372                node_id,
373                backtracking: false,
374            });
375        }
376    }
377
378    fn step_backward(&mut self, node_id: NodeId) {
379        self.partial_solution.pop();
380
381        let mut current_id = self.state.node(node_id).left;
382        loop {
383            self.uncover(current_id);
384
385            if current_id == node_id {
386                break;
387            }
388            current_id = self.state.node(current_id).left;
389        }
390
391        let node_down = self.state.node(node_id).down;
392        let node_header = self.state.node(node_id).header;
393
394        if node_down != node_header {
395            self.step_stack.push(Step {
396                node_id: node_down,
397                backtracking: false,
398            });
399        }
400    }
401}
402
403impl Iterator for Solver {
404    type Item = Vec<usize>;
405
406    fn next(&mut self) -> Option<Self::Item> {
407        while !self.is_completed() {
408            let step = self.step();
409
410            if step.is_some() {
411                return step;
412            }
413        }
414
415        None
416    }
417}
418
#[cfg(test)]
#[rustfmt::skip]
mod tests {
    use super::*;

    #[test]
    fn test_basic_solve() {
        // [x, x, -, -]
        // [x, -, x, -]
        // [-, x, -, x]
        // [-, -, x, x]
        // [x, x, x, -]
        // [-, x, x, x]
        //
        //  __
        // |# |
        // |# |
        //  ""
        let solver = Solver::new(vec![
            vec![0, 1],
            vec![0, 2],
            vec![1, 3],
            vec![2, 3],
            vec![0, 1, 2],
            vec![1, 2, 3],
        ], vec![0, 2]);

        let solutions = solver.collect::<Vec<_>>();

        assert_eq!(vec![vec![2]], solutions);
    }

    #[test]
    fn test_solve_without_partial_solution() {
        // Same matrix as `test_basic_solve`; rows {0, 3} and {1, 2} are the only covers.
        let solver = Solver::new(vec![
            vec![0, 1],
            vec![0, 2],
            vec![1, 3],
            vec![2, 3],
            vec![0, 1, 2],
            vec![1, 2, 3],
        ], vec![]);

        let solutions = solver.collect::<Vec<_>>();

        assert_eq!(vec![vec![0, 3], vec![1, 2]], solutions);
    }

    #[test]
    fn test_empty_input_is_completed() {
        let solver = Solver::new(vec![], vec![]);

        assert!(solver.is_completed());
        assert_eq!(0, solver.count());
    }

    #[test]
    fn test_unsolvable_dead_end() {
        // Either row leaves a column that no remaining row can cover.
        let solver = Solver::new(vec![
            vec![0, 1],
            vec![1, 2],
        ], vec![]);

        assert_eq!(0, solver.count());
    }

    #[test]
    fn test_unknown_partial_column_is_ignored() {
        let solver = Solver::new(vec![vec![0]], vec![5]);

        let solutions = solver.collect::<Vec<_>>();

        assert_eq!(vec![vec![0]], solutions);
    }
}