dancing_links/
grid.rs

1#![allow(clippy::not_unsafe_ptr_arg_deref)]
2
3//! Dancing links `Grid` implementation for use in the `Solver`.
4
5mod base_node;
6
7use base_node::BaseNode;
8use core::{iter::once, ptr};
9use std::collections::VecDeque;
10
11/// Dancing links grid, support efficient removal of rows and columns.
12#[derive(Debug)]
13pub struct Grid {
14    // This node only left-right neighbors, no children
15    root: *mut Column,
16
17    arena: bumpalo::Bump,
18    columns: Vec<*mut Column>,
19
20    num_columns: usize,
21    max_row: usize,
22}
23
24impl Grid {
25    /// Create a new grid with a specified number of columns, and the given
26    /// coordinates filled.
27    ///
28    /// Rows and columns are based 1 indexed for this grid, matching the
29    /// indexing notation for matrices in general.
30    pub fn new(num_columns: usize, coordinates: impl IntoIterator<Item = (usize, usize)>) -> Self {
31        let arena = bumpalo::Bump::new();
32        let root = Column::new(&arena, 0);
33        let columns = once(root)
34            .chain((1..=num_columns).map(|idx| Column::new(&arena, idx)))
35            .collect::<Vec<_>>();
36
37        // Chain all the columns together, including the sentinel root column.
38        for idx in 0..columns.len() {
39            let next_idx = (idx + 1) % columns.len();
40            let column = columns[idx];
41            let next_column = columns[next_idx];
42
43            Column::add_right(column, next_column);
44        }
45
46        let mut grid = Grid {
47            root,
48            columns,
49            arena,
50            num_columns,
51            max_row: 0,
52        };
53
54        grid.add_all_coordinates(coordinates);
55
56        grid
57    }
58
59    fn add_all_coordinates(&mut self, coordinates: impl IntoIterator<Item = (usize, usize)>) {
60        // Deduct one for the sentinel column
61        let mut columns_data: Vec<Vec<_>> =
62            (0..(self.columns.len() - 1)).map(|_| Vec::new()).collect();
63
64        for (row, column) in coordinates {
65            debug_assert!(
66                row != 0 && column != 0,
67                "row or column should not equal zero [{:?}].",
68                (row, column)
69            );
70            debug_assert!(
71                column <= columns_data.len(),
72                "column idx should be in bounds [{:?}]",
73                column
74            );
75
76            columns_data[column - 1].push((row, column));
77
78            if self.max_row < row {
79                self.max_row = row
80            }
81        }
82
83        for column_data in &mut columns_data {
84            column_data.sort_unstable_by_key(|(k, _)| *k);
85        }
86
87        // Map all the data into nodes
88        let mut nodes: Vec<VecDeque<*mut Node>> = columns_data
89            .into_iter()
90            .map(|column_data| {
91                column_data
92                    .into_iter()
93                    .map(|(row_idx, column_idx)| {
94                        let column = self.columns[column_idx];
95
96                        Node::new(&self.arena, row_idx, column)
97                    })
98                    .collect()
99            })
100            .collect();
101
102        // Then, add all the vertical connections, without wrapping around. Skip the
103        // first (sentinel) column.
104        for (node_column, column_header) in nodes.iter_mut().zip(self.columns.iter().skip(1)) {
105            let pair_it = node_column.iter().zip(node_column.iter().skip(1));
106            for (current_node, next_node) in pair_it {
107                BaseNode::add_below(current_node.cast(), next_node.cast());
108            }
109
110            // Connect first and last to header
111            if let Some(first) = node_column.front() {
112                BaseNode::add_below(column_header.cast(), first.cast());
113
114                if let Some(last) = node_column.back() {
115                    BaseNode::add_above(column_header.cast(), last.cast());
116                }
117            }
118        }
119
120        // Then, add all horizontal connections, with wrap around
121        //
122        // To do this we need to select all nodes which have the same row value
123        // and then chain them together. The column data is in sorted order from
124        // before.
125        //
126        // For each column, collect a list with the top (least row value) node. Then,
127        // for each value in the list, collect a subset that contains all the nodes with
128        // the same least row value. They should also be in column order. This
129        // collection will be linked together with wraparound. Then all those nodes that
130        // were selected for the least subset will be replaced from the list with the
131        // next relevant node from the column.
132
133        let mut top_nodes: Vec<Option<(usize, *mut Node)>> = nodes
134            .iter_mut()
135            .map(|column_data| {
136                let node = column_data.pop_front();
137
138                node.map(|node| unsafe { (ptr::read(node).row, node) })
139            })
140            .collect();
141
142        let mut least_nodes = Vec::<(usize, *mut Node)>::with_capacity(top_nodes.len());
143
144        while top_nodes.iter().any(Option::is_some) {
145            let mut least_row = usize::MAX;
146
147            // Select the subcollection of least row nodes
148            for (idx, row_node_pair) in top_nodes.iter().enumerate() {
149                if let Some((row, node)) = row_node_pair {
150                    use core::cmp::Ordering;
151
152                    match row.cmp(&least_row) {
153                        Ordering::Equal => {
154                            least_nodes.push((idx, *node));
155                        }
156                        Ordering::Less => {
157                            least_nodes.clear();
158                            least_row = *row;
159                            least_nodes.push((idx, *node));
160                        }
161                        Ordering::Greater => {}
162                    }
163                }
164            }
165
166            // Link all the least row nodes together
167            //
168            // This is fine for the case of (least_nodes.len() == 1) bc all nodes started
169            // already linked to themselves.
170            for (idx, (_, node)) in least_nodes.iter().enumerate() {
171                let next_node_idx = (idx + 1) % least_nodes.len();
172                let (_, next_node) = least_nodes[next_node_idx];
173
174                BaseNode::add_right(node.cast(), next_node.cast());
175            }
176
177            // Replace the least row nodes with the next values from their respective
178            // columns.
179            for (column_idx, _) in least_nodes.drain(..) {
180                top_nodes[column_idx] = nodes[column_idx]
181                    .pop_front()
182                    .map(|node| unsafe { (ptr::read(node).row, node) });
183            }
184        }
185    }
186
187    /// Convert the grid to a dense representation.
188    ///
189    /// This takes the original size of the grid, and only put `true` values for
190    /// locations that are still present in the grid (not covered).
191    pub fn to_dense(&self) -> Box<[Box<[bool]>]> {
192        let seen_coords = self.uncovered_columns().flat_map(|column_ptr| {
193            let column_idx = Column::index(column_ptr);
194            Column::row_indices(column_ptr).map(move |row_idx| (row_idx, column_idx))
195        });
196
197        let mut output = vec![false; self.num_columns * self.max_row];
198
199        for (row_idx, column_idx) in seen_coords {
200            output[(row_idx - 1) * self.num_columns + (column_idx - 1)] = true
201        }
202
203        if self.num_columns == 0 {
204            debug_assert!(output.is_empty());
205
206            vec![].into_boxed_slice()
207        } else {
208            output
209                .as_slice()
210                .chunks(self.num_columns)
211                .map(Box::<[_]>::from)
212                .collect()
213        }
214    }
215
216    /// Return an iterator of pointers to columns that are uncovered.
217    pub fn uncovered_columns(&self) -> impl Iterator<Item = *const Column> {
218        base_node::iter::right(self.root.cast(), Some(self.root.cast()))
219            .map(|base_ptr| base_ptr.cast::<Column>())
220    }
221
222    /// Return an iterator of mut pointers to columns that are uncovered.
223    pub fn uncovered_columns_mut(&mut self) -> impl Iterator<Item = *mut Column> {
224        base_node::iter::right_mut(self.root.cast(), Some(self.root.cast()))
225            .map(|base_ptr| base_ptr.cast::<Column>())
226    }
227
228    /// Return an iterator over all columns that are in the grid (covered and
229    /// uncovered).
230    pub fn all_columns_mut(&mut self) -> impl DoubleEndedIterator<Item = *mut Column> + '_ {
231        self.columns
232            .iter()
233            .copied()
234            // Skip the sentinel
235            .skip(1)
236    }
237
238    /// Return a pointer to a specific `Column`, if it exists.
239    pub fn get_column(&self, index: usize) -> Option<*const Column> {
240        self.columns
241            .get(index)
242            .copied()
243            .map(|column_ptr| column_ptr as *const _)
244    }
245
246    /// Return a mut pointer to a specific `Column`, if it exists.
247    pub fn get_column_mut(&mut self, index: usize) -> Option<*mut Column> {
248        self.columns.get(index).copied()
249    }
250
251    /// Return true if there are no uncovered columns in the grid.
252    pub fn is_empty(&self) -> bool {
253        unsafe {
254            let column = ptr::read(self.root);
255
256            (column.base.right as *const _) == self.root.cast()
257        }
258    }
259}
260
261/// A coordinate inside of a `Grid`.
262#[derive(Debug, PartialEq, Eq, Hash)]
263#[repr(C)]
264pub struct Node {
265    base: BaseNode,
266
267    row: usize,
268    column: *mut Column,
269}
270
271impl Node {
272    fn new(arena: &bumpalo::Bump, row: usize, column: *mut Column) -> *mut Self {
273        Column::increment_size(column);
274
275        let node = arena.alloc(Node {
276            base: BaseNode::new(),
277
278            row,
279            column,
280        });
281
282        node.base.set_self_ptr();
283
284        node
285    }
286
287    /// Cover every `Node` that is horizontally adjacent to this `Node`.
288    ///
289    /// This `Node` is not covered.
290    pub fn cover_row(self_ptr: *mut Node) {
291        // Skip over the originating node in the row so that it can be recovered from
292        // the column.
293        base_node::iter::right_mut(self_ptr.cast(), Some(self_ptr.cast())).for_each(
294            |base_ptr| unsafe {
295                let node = ptr::read(base_ptr.cast::<Node>());
296
297                Column::decrement_size(node.column);
298                BaseNode::cover_vertical(base_ptr);
299            },
300        )
301    }
302
303    /// Uncover every `Node` that is horizontally adjacent to this `Node`.
304    ///
305    /// This `Node` is not uncovered.
306    pub fn uncover_row(self_ptr: *mut Self) {
307        let base_ptr = self_ptr.cast::<BaseNode>();
308
309        base_node::iter::left_mut(base_ptr, Some(base_ptr)).for_each(|base_ptr| unsafe {
310            let node = ptr::read(base_ptr.cast::<Node>());
311
312            Column::increment_size(node.column);
313            BaseNode::uncover_vertical(base_ptr);
314        })
315    }
316
317    /// Return the row index of this `Node`.
318    pub fn row_index(self_ptr: *const Self) -> usize {
319        unsafe { ptr::read(self_ptr).row }
320    }
321
322    /// Return the column index of this `Node`.
323    pub fn column_index(self_ptr: *const Self) -> usize {
324        unsafe {
325            let node = ptr::read(self_ptr);
326            let column = ptr::read(node.column);
327
328            column.index
329        }
330    }
331
332    /// Return a mut pointer to the `Column` of this `Node`.
333    pub fn column_ptr(self_ptr: *const Self) -> *mut Column {
334        unsafe {
335            let node = ptr::read(self_ptr);
336
337            node.column
338        }
339    }
340
341    /// Return an iterator over all `Node`s that are adjacent to this `Node`.
342    pub fn neighbors(self_ptr: *const Self) -> impl Iterator<Item = *const Node> {
343        base_node::iter::left(self_ptr.cast(), None).map(|base_ptr| base_ptr.cast())
344    }
345}
346
347/// A column inside of a `Grid`.
348#[derive(Debug, PartialEq, Eq, Hash)]
349#[repr(C)]
350pub struct Column {
351    base: BaseNode,
352
353    size: usize,
354    index: usize,
355    is_covered: bool,
356}
357
358impl Column {
359    fn new(arena: &bumpalo::Bump, index: usize) -> *mut Self {
360        let column = arena.alloc(Column {
361            base: BaseNode::new(),
362            size: 0,
363            is_covered: false,
364            index,
365        });
366
367        column.base.set_self_ptr();
368
369        column
370    }
371
372    fn increment_size(self_ptr: *mut Self) {
373        unsafe {
374            let mut column = ptr::read(self_ptr);
375
376            column.size += 1;
377
378            ptr::write(self_ptr, column);
379        }
380    }
381
382    fn decrement_size(self_ptr: *mut Self) {
383        unsafe {
384            let mut column = ptr::read(self_ptr);
385
386            column.size -= 1;
387
388            ptr::write(self_ptr, column);
389        }
390    }
391
392    /// Cover entire column, and any rows that that appear in this column.
393    pub fn cover(self_ptr: *mut Self) {
394        let mut column = unsafe { ptr::read(self_ptr) };
395        assert!(!column.is_covered);
396
397        let base_ptr = self_ptr.cast::<BaseNode>();
398
399        BaseNode::cover_horizontal(base_ptr);
400
401        base_node::iter::down_mut(base_ptr, Some(base_ptr))
402            .for_each(|base_ptr| Node::cover_row(base_ptr.cast()));
403
404        column.is_covered = true;
405        unsafe {
406            ptr::write(self_ptr, column);
407        }
408    }
409
410    /// Uncover entire column, and any rows that appear in this column.
411    pub fn uncover(self_ptr: *mut Self) {
412        let mut column = unsafe { ptr::read(self_ptr) };
413        assert!(column.is_covered);
414
415        let base_ptr = self_ptr.cast::<BaseNode>();
416
417        base_node::iter::up_mut(base_ptr, Some(base_ptr))
418            .for_each(|base_ptr| Node::uncover_row(base_ptr.cast()));
419
420        BaseNode::uncover_horizontal(base_ptr);
421
422        column.is_covered = false;
423        unsafe {
424            ptr::write(self_ptr, column);
425        }
426    }
427
428    fn add_right(self_ptr: *mut Self, neighbor_ptr: *mut Column) {
429        BaseNode::add_right(self_ptr.cast(), neighbor_ptr.cast());
430    }
431
432    /// Return true if there are no uncovered `Node`s in this column.
433    pub fn is_empty(self_ptr: *const Self) -> bool {
434        unsafe {
435            let column = ptr::read(self_ptr);
436
437            let empty = (column.base.down as *const _) == self_ptr;
438
439            debug_assert!(
440                !empty && Self::size(self_ptr) == 0,
441                "The size should be tracked accurately."
442            );
443
444            empty
445        }
446    }
447
448    /// Return an iterator over the row indices of all uncovered `Node`s in this
449    /// column.
450    pub fn row_indices(self_ptr: *const Self) -> impl Iterator<Item = usize> {
451        Column::rows(self_ptr).map(|node_ptr| unsafe { ptr::read(node_ptr).row })
452    }
453
454    /// Return an iterator of pointers to all uncovered `Node`s in this column.
455    pub fn rows(self_ptr: *const Self) -> impl Iterator<Item = *const Node> {
456        base_node::iter::down(self_ptr.cast(), Some(self_ptr.cast()))
457            .map(|base_ptr| base_ptr.cast())
458    }
459
460    /// Return an iterator of mut pointers to all uncovered `Node`s in this
461    /// column.
462    pub fn nodes_mut(self_ptr: *mut Self) -> impl Iterator<Item = *mut Node> {
463        base_node::iter::down_mut(self_ptr.cast(), Some(self_ptr.cast()))
464            .map(|base_ptr| base_ptr.cast())
465    }
466
467    /// Return the column index.
468    #[inline]
469    pub fn index(self_ptr: *const Self) -> usize {
470        unsafe { ptr::read(self_ptr).index }
471    }
472
473    /// Return the number of uncovered nodes in this column.
474    #[inline]
475    pub fn size(self_ptr: *const Self) -> usize {
476        unsafe { ptr::read(self_ptr).size }
477    }
478}
479
/// Render a grid as text for debugging: one line per dense row, with each cell
/// shown as `1` (present) or `0` (absent), e.g. `[1, 0, 0, 1]`. A grid whose
/// dense form is empty renders as `Empty!`.
///
/// This should only be used for test functions.
#[cfg(test)]
pub fn to_string(grid: &Grid) -> String {
    let dense = grid.to_dense();

    if dense.is_empty() {
        return String::from("Empty!\n");
    }

    dense
        .iter()
        .map(|row| {
            let bits: Vec<i32> = row.iter().map(|&cell| i32::from(cell)).collect();

            format!("{:?}\n", bits)
        })
        .collect()
}
510
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an expected `Grid::to_dense` result from row-major `cells`,
    /// `width` cells per row. `width` must be non-zero.
    fn dense(width: usize, cells: &[bool]) -> Box<[Box<[bool]>]> {
        cells.chunks(width).map(Box::<[_]>::from).collect()
    }

    /// Indices of the grid's uncovered columns, in iteration order.
    fn uncovered_indices(grid: &Grid) -> Vec<usize> {
        grid.uncovered_columns().map(Column::index).collect()
    }

    #[test]
    #[rustfmt::skip]
    fn create_a_small_grid() {
        let grid = Grid::new(4, vec![(1, 1), (1, 4), (2, 2), (3, 3), (4, 1), (4, 4)]);

        assert_eq!(
            grid.to_dense(),
            dense(4, &[
                true,  false, false, true,
                false, true,  false, false,
                false, false, true,  false,
                true,  false, false, true,
            ])
        );
    }

    #[test]
    #[rustfmt::skip]
    fn create_weird_grids() {
        let thin_grid = Grid::new(1, vec![
            (1, 1),
            (2, 1),
            (3, 1),
            // skip 4
            (5, 1),
            // skip 6, 7
            (8, 1),
        ]);

        // Skipped rows still show up in the dense output as all-`false` rows:
        // the output always spans rows 1 through the largest row index the
        // grid was built with.
        assert_eq!(
            thin_grid.to_dense(),
            dense(1, &[true, true, true, false, true, false, false, true])
        );
        assert!(!thin_grid.is_empty());

        let very_thin_grid = Grid::new(0, vec![]);

        assert_eq!(very_thin_grid.to_dense(), vec![].into_boxed_slice());
        assert!(very_thin_grid.is_empty());
    }

    #[test]
    #[rustfmt::skip]
    fn cover_uncover_column() {
        let mut grid = Grid::new(4, vec![(1, 1), (1, 4), (2, 2), (3, 3), (4, 1), (4, 4)]);

        // Covering column 4 also removes rows 1 and 4, which have a node in it.
        Column::cover(grid.all_columns_mut().nth(3).unwrap());

        assert_eq!(uncovered_indices(&grid), [1, 2, 3]);
        assert_eq!(
            grid.to_dense(),
            dense(4, &[
                false, false, false, false,
                false, true,  false, false,
                false, false, true,  false,
                false, false, false, false,
            ])
        );

        Column::uncover(grid.all_columns_mut().nth(3).unwrap());

        assert_eq!(uncovered_indices(&grid), [1, 2, 3, 4]);
        assert_eq!(
            grid.to_dense(),
            dense(4, &[
                true,  false, false, true,
                false, true,  false, false,
                false, false, true,  false,
                true,  false, false, true,
            ])
        );
    }

    #[test]
    #[rustfmt::skip]
    fn cover_uncover_all() {
        let mut grid = Grid::new(4, vec![
            (1, 1),                 (1, 4),
                    (2, 2),
                            (3, 3),
            (4, 1),                 (4, 4)
        ]);

        for column_ptr in grid.all_columns_mut() {
            Column::cover(column_ptr)
        }

        assert!(uncovered_indices(&grid).is_empty());
        assert_eq!(grid.to_dense(), dense(4, &[false; 16]));
        assert!(grid.is_empty());

        // Uncover in the reverse order of covering.
        for column_ptr in grid.all_columns_mut().rev() {
            Column::uncover(column_ptr)
        }

        assert_eq!(uncovered_indices(&grid), [1, 2, 3, 4]);
        assert_eq!(
            grid.to_dense(),
            dense(4, &[
                true,  false, false, true,
                false, true,  false, false,
                false, false, true,  false,
                true,  false, false, true,
            ])
        );
        assert!(!grid.is_empty());
    }

    #[test]
    #[rustfmt::skip]
    fn latin_square_cover_1() {
        // [1, 0, 0, 0, 1, 0]
        // [0, 1, 1, 0, 1, 0]
        // [1, 0, 0, 1, 0, 1]
        // [0, 1, 0, 0, 0, 1]
        let mut grid = Grid::new(6, vec![
            (1, 1),                         (1, 5),
                    (2, 2), (2, 3),         (2, 5),
            (3, 1),                 (3, 4),         (3, 6),
                    (4, 2),                         (4, 6),
        ]);

        assert_eq!(
            grid.to_dense(),
            dense(6, &[
                true,  false, false, false, true,  false,
                false, true,  true,  false, true,  false,
                true,  false, false, true,  false, true,
                false, true,  false, false, false, true,
            ])
        );
        assert!(!grid.is_empty());

        Column::cover(grid.get_column_mut(2).unwrap());
        Column::cover(grid.get_column_mut(3).unwrap());
        Column::cover(grid.get_column_mut(5).unwrap());

        assert_eq!(
            grid.to_dense(),
            dense(6, &[
                false, false, false, false, false, false,
                false, false, false, false, false, false,
                true,  false, false, true,  false, true,
                false, false, false, false, false, false,
            ])
        );
    }

    #[test]
    fn render_grid_as_string() {
        let grid = Grid::new(2, vec![(1, 1), (2, 2)]);
        assert_eq!(to_string(&grid), "[1, 0]\n[0, 1]\n");

        let empty_grid = Grid::new(0, vec![]);
        assert_eq!(to_string(&empty_grid), "Empty!\n");
    }
}