1#![allow(clippy::not_unsafe_ptr_arg_deref)]
2
3mod base_node;
6
7use base_node::BaseNode;
8use core::{iter::once, ptr};
9use std::collections::VecDeque;
10
/// Sparse 0/1 matrix stored as circularly linked headers and nodes (dancing-links style).
///
/// Every header and node is allocated in `arena` and referenced through raw pointers, so none
/// of those pointers may be used after the grid is dropped.
#[derive(Debug)]
pub struct Grid {
    /// Sentinel header (index 0). The uncovered column headers form a circular list through it.
    root: *mut Column,

    /// Owns every `Column` and `Node` allocation of this grid.
    arena: bumpalo::Bump,
    /// Every header ever created, covered or not: `columns[0]` is `root`, and `columns[i]` is
    /// column `i`.
    columns: Vec<*mut Column>,

    /// Number of real (non-root) columns.
    num_columns: usize,
    /// Largest row index seen while building; this is the height of `to_dense`'s output.
    max_row: usize,
}
23
24impl Grid {
25 pub fn new(num_columns: usize, coordinates: impl IntoIterator<Item = (usize, usize)>) -> Self {
31 let arena = bumpalo::Bump::new();
32 let root = Column::new(&arena, 0);
33 let columns = once(root)
34 .chain((1..=num_columns).map(|idx| Column::new(&arena, idx)))
35 .collect::<Vec<_>>();
36
37 for idx in 0..columns.len() {
39 let next_idx = (idx + 1) % columns.len();
40 let column = columns[idx];
41 let next_column = columns[next_idx];
42
43 Column::add_right(column, next_column);
44 }
45
46 let mut grid = Grid {
47 root,
48 columns,
49 arena,
50 num_columns,
51 max_row: 0,
52 };
53
54 grid.add_all_coordinates(coordinates);
55
56 grid
57 }
58
59 fn add_all_coordinates(&mut self, coordinates: impl IntoIterator<Item = (usize, usize)>) {
60 let mut columns_data: Vec<Vec<_>> =
62 (0..(self.columns.len() - 1)).map(|_| Vec::new()).collect();
63
64 for (row, column) in coordinates {
65 debug_assert!(
66 row != 0 && column != 0,
67 "row or column should not equal zero [{:?}].",
68 (row, column)
69 );
70 debug_assert!(
71 column <= columns_data.len(),
72 "column idx should be in bounds [{:?}]",
73 column
74 );
75
76 columns_data[column - 1].push((row, column));
77
78 if self.max_row < row {
79 self.max_row = row
80 }
81 }
82
83 for column_data in &mut columns_data {
84 column_data.sort_unstable_by_key(|(k, _)| *k);
85 }
86
87 let mut nodes: Vec<VecDeque<*mut Node>> = columns_data
89 .into_iter()
90 .map(|column_data| {
91 column_data
92 .into_iter()
93 .map(|(row_idx, column_idx)| {
94 let column = self.columns[column_idx];
95
96 Node::new(&self.arena, row_idx, column)
97 })
98 .collect()
99 })
100 .collect();
101
102 for (node_column, column_header) in nodes.iter_mut().zip(self.columns.iter().skip(1)) {
105 let pair_it = node_column.iter().zip(node_column.iter().skip(1));
106 for (current_node, next_node) in pair_it {
107 BaseNode::add_below(current_node.cast(), next_node.cast());
108 }
109
110 if let Some(first) = node_column.front() {
112 BaseNode::add_below(column_header.cast(), first.cast());
113
114 if let Some(last) = node_column.back() {
115 BaseNode::add_above(column_header.cast(), last.cast());
116 }
117 }
118 }
119
120 let mut top_nodes: Vec<Option<(usize, *mut Node)>> = nodes
134 .iter_mut()
135 .map(|column_data| {
136 let node = column_data.pop_front();
137
138 node.map(|node| unsafe { (ptr::read(node).row, node) })
139 })
140 .collect();
141
142 let mut least_nodes = Vec::<(usize, *mut Node)>::with_capacity(top_nodes.len());
143
144 while top_nodes.iter().any(Option::is_some) {
145 let mut least_row = usize::MAX;
146
147 for (idx, row_node_pair) in top_nodes.iter().enumerate() {
149 if let Some((row, node)) = row_node_pair {
150 use core::cmp::Ordering;
151
152 match row.cmp(&least_row) {
153 Ordering::Equal => {
154 least_nodes.push((idx, *node));
155 }
156 Ordering::Less => {
157 least_nodes.clear();
158 least_row = *row;
159 least_nodes.push((idx, *node));
160 }
161 Ordering::Greater => {}
162 }
163 }
164 }
165
166 for (idx, (_, node)) in least_nodes.iter().enumerate() {
171 let next_node_idx = (idx + 1) % least_nodes.len();
172 let (_, next_node) = least_nodes[next_node_idx];
173
174 BaseNode::add_right(node.cast(), next_node.cast());
175 }
176
177 for (column_idx, _) in least_nodes.drain(..) {
180 top_nodes[column_idx] = nodes[column_idx]
181 .pop_front()
182 .map(|node| unsafe { (ptr::read(node).row, node) });
183 }
184 }
185 }
186
187 pub fn to_dense(&self) -> Box<[Box<[bool]>]> {
192 let seen_coords = self.uncovered_columns().flat_map(|column_ptr| {
193 let column_idx = Column::index(column_ptr);
194 Column::row_indices(column_ptr).map(move |row_idx| (row_idx, column_idx))
195 });
196
197 let mut output = vec![false; self.num_columns * self.max_row];
198
199 for (row_idx, column_idx) in seen_coords {
200 output[(row_idx - 1) * self.num_columns + (column_idx - 1)] = true
201 }
202
203 if self.num_columns == 0 {
204 debug_assert!(output.is_empty());
205
206 vec![].into_boxed_slice()
207 } else {
208 output
209 .as_slice()
210 .chunks(self.num_columns)
211 .map(Box::<[_]>::from)
212 .collect()
213 }
214 }
215
216 pub fn uncovered_columns(&self) -> impl Iterator<Item = *const Column> {
218 base_node::iter::right(self.root.cast(), Some(self.root.cast()))
219 .map(|base_ptr| base_ptr.cast::<Column>())
220 }
221
222 pub fn uncovered_columns_mut(&mut self) -> impl Iterator<Item = *mut Column> {
224 base_node::iter::right_mut(self.root.cast(), Some(self.root.cast()))
225 .map(|base_ptr| base_ptr.cast::<Column>())
226 }
227
228 pub fn all_columns_mut(&mut self) -> impl DoubleEndedIterator<Item = *mut Column> + '_ {
231 self.columns
232 .iter()
233 .copied()
234 .skip(1)
236 }
237
238 pub fn get_column(&self, index: usize) -> Option<*const Column> {
240 self.columns
241 .get(index)
242 .copied()
243 .map(|column_ptr| column_ptr as *const _)
244 }
245
246 pub fn get_column_mut(&mut self, index: usize) -> Option<*mut Column> {
248 self.columns.get(index).copied()
249 }
250
251 pub fn is_empty(&self) -> bool {
253 unsafe {
254 let column = ptr::read(self.root);
255
256 (column.base.right as *const _) == self.root.cast()
257 }
258 }
259}
260
/// One entry of the grid, located at (`row`, `column`).
///
/// `#[repr(C)]` with `base` as the first field is what allows the pointer casts between
/// `*mut Node` and `*mut BaseNode` used throughout this module.
#[derive(Debug, PartialEq, Eq, Hash)]
#[repr(C)]
pub struct Node {
    /// Links to neighbouring nodes and headers.
    base: BaseNode,

    /// 1-based row index.
    row: usize,
    /// Header of the column this node belongs to.
    column: *mut Column,
}
270
271impl Node {
272 fn new(arena: &bumpalo::Bump, row: usize, column: *mut Column) -> *mut Self {
273 Column::increment_size(column);
274
275 let node = arena.alloc(Node {
276 base: BaseNode::new(),
277
278 row,
279 column,
280 });
281
282 node.base.set_self_ptr();
283
284 node
285 }
286
287 pub fn cover_row(self_ptr: *mut Node) {
291 base_node::iter::right_mut(self_ptr.cast(), Some(self_ptr.cast())).for_each(
294 |base_ptr| unsafe {
295 let node = ptr::read(base_ptr.cast::<Node>());
296
297 Column::decrement_size(node.column);
298 BaseNode::cover_vertical(base_ptr);
299 },
300 )
301 }
302
303 pub fn uncover_row(self_ptr: *mut Self) {
307 let base_ptr = self_ptr.cast::<BaseNode>();
308
309 base_node::iter::left_mut(base_ptr, Some(base_ptr)).for_each(|base_ptr| unsafe {
310 let node = ptr::read(base_ptr.cast::<Node>());
311
312 Column::increment_size(node.column);
313 BaseNode::uncover_vertical(base_ptr);
314 })
315 }
316
317 pub fn row_index(self_ptr: *const Self) -> usize {
319 unsafe { ptr::read(self_ptr).row }
320 }
321
322 pub fn column_index(self_ptr: *const Self) -> usize {
324 unsafe {
325 let node = ptr::read(self_ptr);
326 let column = ptr::read(node.column);
327
328 column.index
329 }
330 }
331
332 pub fn column_ptr(self_ptr: *const Self) -> *mut Column {
334 unsafe {
335 let node = ptr::read(self_ptr);
336
337 node.column
338 }
339 }
340
341 pub fn neighbors(self_ptr: *const Self) -> impl Iterator<Item = *const Node> {
343 base_node::iter::left(self_ptr.cast(), None).map(|base_ptr| base_ptr.cast())
344 }
345}
346
/// Header of one grid column (index 0 is the grid's root sentinel).
///
/// `#[repr(C)]` with `base` first lets a `*mut Column` be cast to `*mut BaseNode` and back, so
/// headers can sit in the same circular lists as nodes.
#[derive(Debug, PartialEq, Eq, Hash)]
#[repr(C)]
pub struct Column {
    /// Links: left/right to neighbouring headers, up/down to this column's nodes.
    base: BaseNode,

    /// Number of nodes currently linked into this column.
    size: usize,
    /// Position of this header in `Grid::columns`.
    index: usize,
    /// Whether `Column::cover` has been applied without a matching `Column::uncover`.
    is_covered: bool,
}
357
358impl Column {
359 fn new(arena: &bumpalo::Bump, index: usize) -> *mut Self {
360 let column = arena.alloc(Column {
361 base: BaseNode::new(),
362 size: 0,
363 is_covered: false,
364 index,
365 });
366
367 column.base.set_self_ptr();
368
369 column
370 }
371
372 fn increment_size(self_ptr: *mut Self) {
373 unsafe {
374 let mut column = ptr::read(self_ptr);
375
376 column.size += 1;
377
378 ptr::write(self_ptr, column);
379 }
380 }
381
382 fn decrement_size(self_ptr: *mut Self) {
383 unsafe {
384 let mut column = ptr::read(self_ptr);
385
386 column.size -= 1;
387
388 ptr::write(self_ptr, column);
389 }
390 }
391
392 pub fn cover(self_ptr: *mut Self) {
394 let mut column = unsafe { ptr::read(self_ptr) };
395 assert!(!column.is_covered);
396
397 let base_ptr = self_ptr.cast::<BaseNode>();
398
399 BaseNode::cover_horizontal(base_ptr);
400
401 base_node::iter::down_mut(base_ptr, Some(base_ptr))
402 .for_each(|base_ptr| Node::cover_row(base_ptr.cast()));
403
404 column.is_covered = true;
405 unsafe {
406 ptr::write(self_ptr, column);
407 }
408 }
409
410 pub fn uncover(self_ptr: *mut Self) {
412 let mut column = unsafe { ptr::read(self_ptr) };
413 assert!(column.is_covered);
414
415 let base_ptr = self_ptr.cast::<BaseNode>();
416
417 base_node::iter::up_mut(base_ptr, Some(base_ptr))
418 .for_each(|base_ptr| Node::uncover_row(base_ptr.cast()));
419
420 BaseNode::uncover_horizontal(base_ptr);
421
422 column.is_covered = false;
423 unsafe {
424 ptr::write(self_ptr, column);
425 }
426 }
427
428 fn add_right(self_ptr: *mut Self, neighbor_ptr: *mut Column) {
429 BaseNode::add_right(self_ptr.cast(), neighbor_ptr.cast());
430 }
431
432 pub fn is_empty(self_ptr: *const Self) -> bool {
434 unsafe {
435 let column = ptr::read(self_ptr);
436
437 let empty = (column.base.down as *const _) == self_ptr;
438
439 debug_assert!(
440 !empty && Self::size(self_ptr) == 0,
441 "The size should be tracked accurately."
442 );
443
444 empty
445 }
446 }
447
448 pub fn row_indices(self_ptr: *const Self) -> impl Iterator<Item = usize> {
451 Column::rows(self_ptr).map(|node_ptr| unsafe { ptr::read(node_ptr).row })
452 }
453
454 pub fn rows(self_ptr: *const Self) -> impl Iterator<Item = *const Node> {
456 base_node::iter::down(self_ptr.cast(), Some(self_ptr.cast()))
457 .map(|base_ptr| base_ptr.cast())
458 }
459
460 pub fn nodes_mut(self_ptr: *mut Self) -> impl Iterator<Item = *mut Node> {
463 base_node::iter::down_mut(self_ptr.cast(), Some(self_ptr.cast()))
464 .map(|base_ptr| base_ptr.cast())
465 }
466
467 #[inline]
469 pub fn index(self_ptr: *const Self) -> usize {
470 unsafe { ptr::read(self_ptr).index }
471 }
472
473 #[inline]
475 pub fn size(self_ptr: *const Self) -> usize {
476 unsafe { ptr::read(self_ptr).size }
477 }
478}
479
/// Debug rendering of `grid`: one line per dense row, written as a list of `1`/`0` cells, or
/// the single line `Empty!` when the dense matrix has no rows.
#[cfg(test)]
pub fn to_string(grid: &Grid) -> String {
    use std::fmt::Write;

    let dense = grid.to_dense();
    let mut rendered = String::new();

    if dense.is_empty() {
        writeln!(rendered, "Empty!").unwrap();
        return rendered;
    }

    for row in dense.iter() {
        let bits: Vec<u8> = row.iter().map(|&cell| u8::from(cell)).collect();
        writeln!(rendered, "{:?}", bits).unwrap();
    }

    rendered
}
510
#[cfg(test)]
mod tests {
    use super::*;

    /// Reshapes a flat, row-major list of cells into the boxed-rows form produced by
    /// `Grid::to_dense`. `width` must be non-zero.
    fn rows_of(cells: &[bool], width: usize) -> Box<[Box<[bool]>]> {
        cells.chunks(width).map(Box::<[_]>::from).collect()
    }

    /// Indices of the columns currently linked into the root's header ring, in ring order.
    fn uncovered_indices(grid: &Grid) -> Vec<usize> {
        grid.uncovered_columns().map(Column::index).collect()
    }

    #[test]
    #[rustfmt::skip]
    fn create_a_small_grid() {
        let grid = Grid::new(4, vec![(1, 1), (1, 4), (2, 2), (3, 3), (4, 1), (4, 4)]);

        let expected = rows_of(&[
            true,  false, false, true,
            false, true,  false, false,
            false, false, true,  false,
            true,  false, false, true,
        ], 4);

        assert_eq!(grid.to_dense(), expected);
    }

    #[test]
    #[rustfmt::skip]
    fn create_weird_grids() {
        // A single column with gaps at rows 4, 6 and 7.
        let thin_grid = Grid::new(1, vec![(1, 1), (2, 1), (3, 1), (5, 1), (8, 1)]);

        let expected = rows_of(&[true, true, true, false, true, false, false, true], 1);
        assert_eq!(thin_grid.to_dense(), expected);
        assert!(!thin_grid.is_empty());

        // No columns at all: the root is linked only to itself.
        let very_thin_grid = Grid::new(0, vec![]);

        assert_eq!(very_thin_grid.to_dense(), Vec::new().into_boxed_slice());
        assert!(very_thin_grid.is_empty());
    }

    #[test]
    #[rustfmt::skip]
    fn cover_uncover_column() {
        let mut grid = Grid::new(4, vec![(1, 1), (1, 4), (2, 2), (3, 3), (4, 1), (4, 4)]);
        let original = rows_of(&[
            true,  false, false, true,
            false, true,  false, false,
            false, false, true,  false,
            true,  false, false, true,
        ], 4);
        let column_four = grid.get_column_mut(4).unwrap();

        // Covering column 4 hides it and every row (1 and 4) that has an entry in it.
        Column::cover(column_four);

        assert_eq!(uncovered_indices(&grid), vec![1, 2, 3]);
        assert_eq!(grid.to_dense(), rows_of(&[
            false, false, false, false,
            false, true,  false, false,
            false, false, true,  false,
            false, false, false, false,
        ], 4));

        Column::uncover(column_four);

        assert_eq!(uncovered_indices(&grid), vec![1, 2, 3, 4]);
        assert_eq!(grid.to_dense(), original);
    }

    #[test]
    #[rustfmt::skip]
    fn cover_uncover_all() {
        let mut grid = Grid::new(4, vec![
            (1, 1), (1, 4),
            (2, 2),
            (3, 3),
            (4, 1), (4, 4)
        ]);
        let original = rows_of(&[
            true,  false, false, true,
            false, true,  false, false,
            false, false, true,  false,
            true,  false, false, true,
        ], 4);

        for column_ptr in grid.all_columns_mut() {
            Column::cover(column_ptr);
        }

        assert!(uncovered_indices(&grid).is_empty());
        assert_eq!(grid.to_dense(), rows_of(&[false; 16], 4));
        assert!(grid.is_empty());

        // Uncover in the reverse order of covering.
        for column_ptr in grid.all_columns_mut().rev() {
            Column::uncover(column_ptr);
        }

        assert_eq!(uncovered_indices(&grid), vec![1, 2, 3, 4]);
        assert_eq!(grid.to_dense(), original);
        assert!(!grid.is_empty());
    }

    #[test]
    #[rustfmt::skip]
    fn latin_square_cover_1() {
        let mut grid = Grid::new(6, vec![
            (1, 1), (1, 5),
            (2, 2), (2, 3), (2, 5),
            (3, 1), (3, 4), (3, 6),
            (4, 2), (4, 6),
        ]);

        assert_eq!(grid.to_dense(), rows_of(&[
            true,  false, false, false, true,  false,
            false, true,  true,  false, true,  false,
            true,  false, false, true,  false, true,
            false, true,  false, false, false, true,
        ], 6));
        assert!(!grid.is_empty());

        for &index in [2, 3, 5].iter() {
            Column::cover(grid.get_column_mut(index).unwrap());
        }

        // Only row 3 has no entry in columns 2, 3 or 5, so it is the only row left visible.
        assert_eq!(grid.to_dense(), rows_of(&[
            false, false, false, false, false, false,
            false, false, false, false, false, false,
            true,  false, false, true,  false, true,
            false, false, false, false, false, false,
        ], 6));
    }
}