1use crate::error::SpaceError;
4use crate::region::{BoundingShape, RegionPlan, RegionSpec};
5use crate::space::Space;
6use murk_core::{Coord, SpaceInstanceId};
7use smallvec::{smallvec, SmallVec};
8
/// Axial `(dq, dr)` offsets from a cell to its six hex neighbours.
///
/// `Hex2D` reports neighbours in exactly this order (after dropping those
/// outside the grid). The table is closed under negation — `(1, 0)`/`(-1, 0)`,
/// `(1, -1)`/`(-1, 1)`, `(0, -1)`/`(0, 1)` — which is what makes adjacency
/// symmetric.
const HEX_OFFSETS: [(i32, i32); 6] = [
    (1, 0), (1, -1), (0, -1), (-1, 0), (-1, 1), (0, 1), ];
18
/// A `rows` x `cols` hexagonal grid addressed by axial coordinates `[q, r]`.
///
/// Valid cells satisfy `0 <= q < cols` and `0 <= r < rows`. Adjacency is
/// given by `HEX_OFFSETS`, and distances use the hex (cube) metric
/// `max(|dq|, |dr|, |dq + dr|)`.
///
/// Note: the derived `Clone` copies `instance_id`, so a clone reports the same
/// instance id as the original.
#[derive(Debug, Clone)]
pub struct Hex2D {
    // Extent of the `r` axis; `Hex2D::new` guarantees 1..=MAX_DIM.
    rows: u32,
    // Extent of the `q` axis; `Hex2D::new` guarantees 1..=MAX_DIM.
    cols: u32,
    // Assigned from `SpaceInstanceId::next()` when constructed via `new`.
    instance_id: SpaceInstanceId,
}
60
61impl Hex2D {
62 pub const MAX_DIM: u32 = i32::MAX as u32;
64
65 pub fn new(rows: u32, cols: u32) -> Result<Self, SpaceError> {
70 if rows == 0 || cols == 0 {
71 return Err(SpaceError::EmptySpace);
72 }
73 if rows > Self::MAX_DIM {
74 return Err(SpaceError::DimensionTooLarge {
75 name: "rows",
76 value: rows,
77 max: Self::MAX_DIM,
78 });
79 }
80 if cols > Self::MAX_DIM {
81 return Err(SpaceError::DimensionTooLarge {
82 name: "cols",
83 value: cols,
84 max: Self::MAX_DIM,
85 });
86 }
87 Ok(Self {
88 rows,
89 cols,
90 instance_id: SpaceInstanceId::next(),
91 })
92 }
93
94 pub fn rows(&self) -> u32 {
96 self.rows
97 }
98
99 pub fn cols(&self) -> u32 {
101 self.cols
102 }
103
104 pub fn is_empty(&self) -> bool {
106 false
107 }
108
109 fn check_bounds(&self, coord: &Coord) -> Result<(i32, i32), SpaceError> {
111 if coord.len() != 2 {
112 return Err(SpaceError::CoordOutOfBounds {
113 coord: coord.clone(),
114 bounds: format!("expected 2D coordinate, got {}D", coord.len()),
115 });
116 }
117 let q = coord[0];
118 let r = coord[1];
119 if q < 0 || q >= self.cols as i32 || r < 0 || r >= self.rows as i32 {
120 return Err(SpaceError::CoordOutOfBounds {
121 coord: coord.clone(),
122 bounds: format!("q in [0, {}), r in [0, {})", self.cols, self.rows),
123 });
124 }
125 Ok((q, r))
126 }
127
128 fn neighbours_qr(&self, q: i32, r: i32) -> SmallVec<[(i32, i32); 6]> {
130 let mut result = SmallVec::new();
131 for (dq, dr) in HEX_OFFSETS {
132 let nq = q + dq;
133 let nr = r + dr;
134 if nq >= 0 && nq < self.cols as i32 && nr >= 0 && nr < self.rows as i32 {
135 result.push((nq, nr));
136 }
137 }
138 result
139 }
140
141 fn cube_distance(q1: i32, r1: i32, q2: i32, r2: i32) -> i32 {
143 let dq = (q1 - q2).abs();
144 let dr = (r1 - r2).abs();
145 let ds = ((q1 + r1) - (q2 + r2)).abs(); dq.max(dr).max(ds)
147 }
148
149 fn compile_hex_disk(
151 &self,
152 center_q: i32,
153 center_r: i32,
154 radius: u32,
155 ) -> Result<RegionPlan, SpaceError> {
156 let max_useful = (self.rows as u64 + self.cols as u64).min(i32::MAX as u64) as u32;
159 let eff_radius = radius.min(max_useful);
160 let r = eff_radius as i32;
161 let side = 2i64 * r as i64 + 1;
162 let bounding_size = side
163 .checked_mul(side)
164 .ok_or_else(|| SpaceError::InvalidRegion {
165 reason: format!(
166 "hex disk bounding area overflow: side={side} exceeds i64 when squared"
167 ),
168 })? as usize;
169 let mut valid_mask = vec![0u8; bounding_size];
170 let mut coords = Vec::new();
171 let mut tensor_indices = Vec::new();
172
173 for dr in -r..=r {
175 for dq in -r..=r {
176 if Self::cube_distance(0, 0, dq, dr) > r {
177 continue;
178 }
179 let q = center_q + dq;
180 let rv = center_r + dr;
181 if q < 0 || q >= self.cols as i32 || rv < 0 || rv >= self.rows as i32 {
182 continue;
183 }
184 let tensor_idx = ((dr + r) as i64 * side + (dq + r) as i64) as usize;
185 valid_mask[tensor_idx] = 1;
186 coords.push(smallvec![q, rv]);
187 tensor_indices.push(tensor_idx);
188 }
189 }
190
191 let mut pairs: Vec<(Coord, usize)> = coords.into_iter().zip(tensor_indices).collect();
193 pairs.sort_by(|a, b| {
194 let ar = a.0[1];
195 let aq = a.0[0];
196 let br = b.0[1];
197 let bq = b.0[0];
198 (ar, aq).cmp(&(br, bq))
199 });
200 let (coords, tensor_indices): (Vec<_>, Vec<_>) = pairs.into_iter().unzip();
201
202 Ok(RegionPlan {
203 coords,
204 tensor_indices,
205 valid_mask,
206 bounding_shape: BoundingShape::Rect(vec![side as usize, side as usize]),
207 })
208 }
209}
210
211impl Space for Hex2D {
212 fn ndim(&self) -> usize {
213 2
214 }
215
216 fn cell_count(&self) -> usize {
217 (self.rows as usize) * (self.cols as usize)
218 }
219
220 fn neighbours(&self, coord: &Coord) -> SmallVec<[Coord; 8]> {
221 let q = coord[0];
222 let r = coord[1];
223 self.neighbours_qr(q, r)
224 .into_iter()
225 .map(|(nq, nr)| smallvec![nq, nr])
226 .collect()
227 }
228
229 fn max_neighbour_degree(&self) -> usize {
230 match (self.rows, self.cols) {
231 (1, 1) => 0,
232 (1, 2) | (2, 1) => 1,
233 (1, _) | (_, 1) => 2,
234 (2, 2) => 3,
235 (2, _) | (_, 2) => 4,
236 _ => 6,
237 }
238 }
239
240 fn distance(&self, a: &Coord, b: &Coord) -> f64 {
241 Self::cube_distance(a[0], a[1], b[0], b[1]) as f64
242 }
243
244 fn compile_region(&self, spec: &RegionSpec) -> Result<RegionPlan, SpaceError> {
245 match spec {
246 RegionSpec::All => {
247 let coords = self.canonical_ordering();
248 let cell_count = coords.len();
249 let tensor_indices: Vec<usize> = (0..cell_count).collect();
250 let valid_mask = vec![1u8; cell_count];
251 Ok(RegionPlan {
252 coords,
253 tensor_indices,
254 valid_mask,
255 bounding_shape: BoundingShape::Rect(vec![
256 self.rows as usize,
257 self.cols as usize,
258 ]),
259 })
260 }
261
262 RegionSpec::Disk { center, radius } => {
263 let (cq, cr) = self.check_bounds(center)?;
264 self.compile_hex_disk(cq, cr, *radius)
265 }
266
267 RegionSpec::Neighbours { center, depth } => {
268 let (cq, cr) = self.check_bounds(center)?;
269 self.compile_hex_disk(cq, cr, *depth)
270 }
271
272 RegionSpec::Rect { min, max } => {
273 let (q_lo, r_lo) = self.check_bounds(min)?;
274 let (q_hi, r_hi) = self.check_bounds(max)?;
275 if q_lo > q_hi || r_lo > r_hi {
276 return Err(SpaceError::InvalidRegion {
277 reason: format!(
278 "Rect min ({q_lo},{r_lo}) > max ({q_hi},{r_hi}) on some axis"
279 ),
280 });
281 }
282 let mut coords = Vec::new();
284 for r in r_lo..=r_hi {
285 for q in q_lo..=q_hi {
286 coords.push(smallvec![q, r]);
287 }
288 }
289 let cell_count = coords.len();
290 let tensor_indices: Vec<usize> = (0..cell_count).collect();
291 let valid_mask = vec![1u8; cell_count];
292 let shape_rows = (r_hi - r_lo + 1) as usize;
293 let shape_cols = (q_hi - q_lo + 1) as usize;
294 Ok(RegionPlan {
295 coords,
296 tensor_indices,
297 valid_mask,
298 bounding_shape: BoundingShape::Rect(vec![shape_rows, shape_cols]),
299 })
300 }
301
302 RegionSpec::Coords(coords) => {
303 for coord in coords {
304 self.check_bounds(coord)?;
305 }
306 let mut sorted: Vec<Coord> = coords.clone();
307 sorted.sort_by(|a, b| (a[1], a[0]).cmp(&(b[1], b[0])));
309 sorted.dedup();
310 let cell_count = sorted.len();
311 let tensor_indices: Vec<usize> = (0..cell_count).collect();
312 let valid_mask = vec![1u8; cell_count];
313 Ok(RegionPlan {
314 coords: sorted,
315 tensor_indices,
316 valid_mask,
317 bounding_shape: BoundingShape::Rect(vec![cell_count]),
318 })
319 }
320 }
321 }
322
323 fn canonical_ordering(&self) -> Vec<Coord> {
324 let mut out = Vec::with_capacity(self.cell_count());
326 for r in 0..self.rows as i32 {
327 for q in 0..self.cols as i32 {
328 out.push(smallvec![q, r]);
329 }
330 }
331 out
332 }
333
334 fn canonical_rank(&self, coord: &Coord) -> Option<usize> {
335 if coord.len() != 2 {
336 return None;
337 }
338 let q = coord[0];
339 let r = coord[1];
340 if q >= 0 && q < self.cols as i32 && r >= 0 && r < self.rows as i32 {
341 Some(r as usize * self.cols as usize + q as usize)
342 } else {
343 None
344 }
345 }
346
347 fn canonical_rank_slice(&self, coord: &[i32]) -> Option<usize> {
348 if coord.len() != 2 {
349 return None;
350 }
351 let q = coord[0];
352 let r = coord[1];
353 if q >= 0 && q < self.cols as i32 && r >= 0 && r < self.rows as i32 {
354 Some(r as usize * self.cols as usize + q as usize)
355 } else {
356 None
357 }
358 }
359
360 fn instance_id(&self) -> SpaceInstanceId {
361 self.instance_id
362 }
363
364 fn topology_eq(&self, other: &dyn Space) -> bool {
365 (other as &dyn std::any::Any)
366 .downcast_ref::<Self>()
367 .is_some_and(|o| self.rows == o.rows && self.cols == o.cols)
368 }
369}
370
#[cfg(test)]
mod tests {
    use super::*;
    use crate::compliance;
    use murk_core::Coord;
    use proptest::prelude::*;

    /// Shorthand for the axial coordinate `[q, r]`.
    fn c(q: i32, r: i32) -> Coord {
        smallvec![q, r]
    }

    // ── Neighbours ──────────────────────────────────────────────

    #[test]
    fn neighbours_interior() {
        let s = Hex2D::new(5, 5).unwrap();
        let n = s.neighbours(&c(2, 1));
        assert_eq!(n.len(), 6);
        assert!(n.contains(&c(3, 1))); // (+1,  0)
        assert!(n.contains(&c(3, 0))); // (+1, -1)
        assert!(n.contains(&c(2, 0))); // ( 0, -1)
        assert!(n.contains(&c(1, 1))); // (-1,  0)
        assert!(n.contains(&c(1, 2))); // (-1, +1)
        assert!(n.contains(&c(2, 2))); // ( 0, +1)
    }

    #[test]
    fn neighbours_corner_origin() {
        let s = Hex2D::new(5, 5).unwrap();
        let n = s.neighbours(&c(0, 0));
        assert_eq!(n.len(), 2);
        assert!(n.contains(&c(1, 0)));
        assert!(n.contains(&c(0, 1)));
    }

    #[test]
    fn neighbours_top_edge() {
        let s = Hex2D::new(5, 5).unwrap();
        let n = s.neighbours(&c(2, 0));
        assert_eq!(n.len(), 4);
        assert!(n.contains(&c(3, 0)));
        assert!(n.contains(&c(1, 0)));
        assert!(n.contains(&c(1, 1)));
        assert!(n.contains(&c(2, 1)));
    }

    #[test]
    fn neighbours_bottom_right_corner() {
        let s = Hex2D::new(5, 5).unwrap();
        let n = s.neighbours(&c(4, 4));
        assert_eq!(n.len(), 2);
        assert!(n.contains(&c(4, 3)));
        assert!(n.contains(&c(3, 4)));
    }

    #[test]
    fn max_neighbour_degree_matches_observed_maximum() {
        for rows in 1..=4 {
            for cols in 1..=4 {
                let s = Hex2D::new(rows, cols).unwrap();
                let observed = s
                    .canonical_ordering()
                    .iter()
                    .map(|coord| s.neighbours(coord).len())
                    .max()
                    .unwrap();
                assert_eq!(s.max_neighbour_degree(), observed, "rows={rows} cols={cols}");
            }
        }
    }

    // ── Distance ────────────────────────────────────────────────

    #[test]
    fn distance_same_cell() {
        let s = Hex2D::new(5, 5).unwrap();
        assert_eq!(s.distance(&c(2, 1), &c(2, 1)), 0.0);
    }

    #[test]
    fn distance_adjacent() {
        let s = Hex2D::new(5, 5).unwrap();
        assert_eq!(s.distance(&c(2, 1), &c(3, 1)), 1.0);
        assert_eq!(s.distance(&c(2, 1), &c(3, 0)), 1.0);
    }

    #[test]
    fn distance_hld_worked_example() {
        let s = Hex2D::new(5, 5).unwrap();
        // dq = 2, dr = -1, |dq + dr| = 1 -> max = 2.
        assert_eq!(s.distance(&c(2, 1), &c(4, 0)), 2.0);
    }

    #[test]
    fn distance_across_grid() {
        let s = Hex2D::new(5, 5).unwrap();
        // Opposite acute corners: |dq + dr| = 8 dominates.
        assert_eq!(s.distance(&c(0, 0), &c(4, 4)), 8.0);
    }

    // ── Region compilation ──────────────────────────────────────

    #[test]
    fn compile_region_all() {
        let s = Hex2D::new(5, 5).unwrap();
        let plan = s.compile_region(&RegionSpec::All).unwrap();
        assert_eq!(plan.cell_count(), 25);
        assert_eq!(plan.valid_ratio(), 1.0);
    }

    #[test]
    fn compile_region_disk_r1() {
        let s = Hex2D::new(10, 10).unwrap();
        let plan = s
            .compile_region(&RegionSpec::Disk {
                center: c(5, 5),
                radius: 1,
            })
            .unwrap();
        assert_eq!(plan.cell_count(), 7);
    }

    #[test]
    fn compile_region_disk_r2() {
        let s = Hex2D::new(10, 10).unwrap();
        let plan = s
            .compile_region(&RegionSpec::Disk {
                center: c(5, 5),
                radius: 2,
            })
            .unwrap();
        assert_eq!(plan.cell_count(), 19);
    }

    #[test]
    fn compile_region_disk_valid_ratio_r1() {
        let s = Hex2D::new(10, 10).unwrap();
        let plan = s
            .compile_region(&RegionSpec::Disk {
                center: c(5, 5),
                radius: 1,
            })
            .unwrap();
        let ratio = plan.valid_ratio();
        // 7 cells in a 3x3 bounding box.
        assert!((ratio - 7.0 / 9.0).abs() < 0.01, "valid_ratio={ratio}");
    }

    #[test]
    fn compile_region_disk_valid_ratio_r2() {
        let s = Hex2D::new(10, 10).unwrap();
        let plan = s
            .compile_region(&RegionSpec::Disk {
                center: c(5, 5),
                radius: 2,
            })
            .unwrap();
        let ratio = plan.valid_ratio();
        // 19 cells in a 5x5 bounding box.
        assert!((ratio - 19.0 / 25.0).abs() < 0.01, "valid_ratio={ratio}");
    }

    #[test]
    fn compile_region_disk_boundary_truncation() {
        let s = Hex2D::new(5, 5).unwrap();
        let plan = s
            .compile_region(&RegionSpec::Disk {
                center: c(0, 0),
                radius: 2,
            })
            .unwrap();
        // At the origin only offsets with dq >= 0 and dr >= 0 survive, and
        // (2,1)/(1,2) are at distance 3, leaving 6 of the 19 disk cells.
        assert_eq!(plan.cell_count(), 6);
        assert_eq!(
            plan.coords,
            vec![c(0, 0), c(1, 0), c(2, 0), c(0, 1), c(1, 1), c(0, 2)]
        );
        // Bounding box is 5x5 with the center at (2, 2): idx = (dr+2)*5 + (dq+2).
        assert_eq!(plan.tensor_indices, vec![12, 13, 14, 17, 18, 22]);
    }

    #[test]
    fn compile_region_disk_plan_is_consistent() {
        let s = Hex2D::new(10, 10).unwrap();
        let center = c(5, 5);
        let plan = s
            .compile_region(&RegionSpec::Disk {
                center: center.clone(),
                radius: 2,
            })
            .unwrap();
        assert_eq!(plan.coords.len(), plan.tensor_indices.len());
        // Coordinates are strictly increasing in canonical (r, q) order.
        assert!(plan
            .coords
            .windows(2)
            .all(|w| (w[0][1], w[0][0]) < (w[1][1], w[1][0])));
        for (coord, &idx) in plan.coords.iter().zip(&plan.tensor_indices) {
            assert_eq!(plan.valid_mask[idx], 1);
            assert!(s.distance(coord, &center) <= 2.0);
        }
        let valid = plan.valid_mask.iter().filter(|&&v| v == 1).count();
        assert_eq!(valid, plan.coords.len());
    }

    #[test]
    fn compile_region_disk_huge_radius_does_not_overflow() {
        let s = Hex2D::new(3, 3).unwrap();
        // The radius is clamped to rows + cols, so the whole grid is covered.
        let plan = s
            .compile_region(&RegionSpec::Disk {
                center: c(1, 1),
                radius: u32::MAX,
            })
            .unwrap();
        assert_eq!(plan.cell_count(), 9);
    }

    #[test]
    fn compile_hex_disk_overflow_returns_error() {
        let s = Hex2D::new(i32::MAX as u32, 1).unwrap();
        // The clamped radius is ~i32::MAX, so side^2 overflows i64.
        let result = s.compile_region(&RegionSpec::Disk {
            center: c(0, 0),
            radius: u32::MAX,
        });
        assert!(
            result.is_err(),
            "should return error on bounding area overflow"
        );
    }

    #[test]
    fn compile_region_rect() {
        let s = Hex2D::new(10, 10).unwrap();
        let plan = s
            .compile_region(&RegionSpec::Rect {
                min: c(2, 3),
                max: c(5, 6),
            })
            .unwrap();
        assert_eq!(plan.cell_count(), 16);
    }

    #[test]
    fn compile_region_rect_shape_and_order() {
        let s = Hex2D::new(10, 10).unwrap();
        let plan = s
            .compile_region(&RegionSpec::Rect {
                min: c(1, 2),
                max: c(4, 3),
            })
            .unwrap();
        assert_eq!(plan.cell_count(), 8);
        assert_eq!(plan.coords.first(), Some(&c(1, 2)));
        assert_eq!(plan.coords[1], c(2, 2));
        assert_eq!(plan.coords.last(), Some(&c(4, 3)));
        assert!(matches!(
            &plan.bounding_shape,
            BoundingShape::Rect(dims) if dims[..] == [2, 4]
        ));
    }

    #[test]
    fn compile_region_rect_invalid() {
        let s = Hex2D::new(10, 10).unwrap();
        assert!(s
            .compile_region(&RegionSpec::Rect {
                min: c(5, 0),
                max: c(2, 3),
            })
            .is_err());
    }

    #[test]
    fn compile_region_coords() {
        let s = Hex2D::new(5, 5).unwrap();
        let plan = s
            .compile_region(&RegionSpec::Coords(vec![c(3, 1), c(1, 2), c(0, 0)]))
            .unwrap();
        // Sorted r-major, then q.
        assert_eq!(plan.coords, vec![c(0, 0), c(3, 1), c(1, 2)]);
    }

    #[test]
    fn compile_region_coords_dedups() {
        let s = Hex2D::new(5, 5).unwrap();
        let plan = s
            .compile_region(&RegionSpec::Coords(vec![c(1, 1), c(0, 0), c(1, 1)]))
            .unwrap();
        assert_eq!(plan.coords, vec![c(0, 0), c(1, 1)]);
        assert_eq!(plan.tensor_indices, vec![0, 1]);
    }

    #[test]
    fn compile_region_coords_oob() {
        let s = Hex2D::new(5, 5).unwrap();
        assert!(s
            .compile_region(&RegionSpec::Coords(vec![c(10, 0)]))
            .is_err());
    }

    // ── Construction ────────────────────────────────────────────

    #[test]
    fn new_zero_rows_returns_error() {
        assert!(matches!(Hex2D::new(0, 5), Err(SpaceError::EmptySpace)));
    }

    #[test]
    fn new_zero_cols_returns_error() {
        assert!(matches!(Hex2D::new(5, 0), Err(SpaceError::EmptySpace)));
    }

    #[test]
    fn new_rejects_dims_exceeding_i32_max() {
        let big = i32::MAX as u32 + 1;
        assert!(matches!(
            Hex2D::new(big, 5),
            Err(SpaceError::DimensionTooLarge { name: "rows", .. })
        ));
        assert!(matches!(
            Hex2D::new(5, big),
            Err(SpaceError::DimensionTooLarge { name: "cols", .. })
        ));
        assert!(Hex2D::new(i32::MAX as u32, 1).is_ok());
    }

    #[test]
    fn single_cell() {
        let s = Hex2D::new(1, 1).unwrap();
        assert!(s.neighbours(&c(0, 0)).is_empty());
        assert_eq!(s.cell_count(), 1);
        assert_eq!(s.distance(&c(0, 0), &c(0, 0)), 0.0);
    }

    // ── Canonical ordering ──────────────────────────────────────

    #[test]
    fn canonical_ordering_r_then_q() {
        let s = Hex2D::new(3, 3).unwrap();
        let order = s.canonical_ordering();
        assert_eq!(
            order,
            vec![
                c(0, 0),
                c(1, 0),
                c(2, 0),
                c(0, 1),
                c(1, 1),
                c(2, 1),
                c(0, 2),
                c(1, 2),
                c(2, 2),
            ]
        );
    }

    #[test]
    fn canonical_rank_matches_ordering() {
        let s = Hex2D::new(3, 4).unwrap();
        for (i, coord) in s.canonical_ordering().iter().enumerate() {
            assert_eq!(s.canonical_rank(coord), Some(i));
            assert_eq!(s.canonical_rank_slice(coord), Some(i));
        }
        assert_eq!(s.canonical_rank(&c(4, 0)), None);
        assert_eq!(s.canonical_rank(&c(0, 3)), None);
        assert_eq!(s.canonical_rank(&c(-1, 0)), None);
        let one: Coord = smallvec![1];
        assert_eq!(s.canonical_rank(&one), None);
        assert_eq!(s.canonical_rank_slice(&[1, 1, 1]), None);
    }

    // ── Compliance suite ────────────────────────────────────────

    #[test]
    fn compliance_3x3() {
        let s = Hex2D::new(3, 3).unwrap();
        compliance::run_full_compliance(&s);
    }

    #[test]
    fn compliance_5x5() {
        let s = Hex2D::new(5, 5).unwrap();
        compliance::run_full_compliance(&s);
    }

    #[test]
    fn compliance_8x8() {
        let s = Hex2D::new(8, 8).unwrap();
        compliance::run_full_compliance(&s);
    }

    // ── Downcasting ─────────────────────────────────────────────

    #[test]
    fn downcast_ref_hex2d() {
        let s: Box<dyn Space> = Box::new(Hex2D::new(3, 3).unwrap());
        assert!(s.downcast_ref::<Hex2D>().is_some());
        assert!(s.downcast_ref::<crate::Square4>().is_none());
    }

    // ── Property tests ──────────────────────────────────────────

    proptest! {
        #[test]
        fn distance_is_metric(
            rows in 2u32..8,
            cols in 2u32..8,
            aq in 0i32..8, ar in 0i32..8,
            bq in 0i32..8, br in 0i32..8,
            cq in 0i32..8, cr in 0i32..8,
        ) {
            let aq = aq % cols as i32;
            let ar = ar % rows as i32;
            let bq = bq % cols as i32;
            let br = br % rows as i32;
            let cq = cq % cols as i32;
            let cr = cr % rows as i32;
            let s = Hex2D::new(rows, cols).unwrap();
            let a: Coord = smallvec![aq, ar];
            let b: Coord = smallvec![bq, br];
            let cv: Coord = smallvec![cq, cr];

            prop_assert!((s.distance(&a, &a) - 0.0).abs() < f64::EPSILON);
            prop_assert!((s.distance(&a, &b) - s.distance(&b, &a)).abs() < f64::EPSILON);
            prop_assert!(s.distance(&a, &cv) <= s.distance(&a, &b) + s.distance(&b, &cv) + f64::EPSILON);
        }

        #[test]
        fn neighbours_symmetric(
            rows in 2u32..8,
            cols in 2u32..8,
            q in 0i32..8, r in 0i32..8,
        ) {
            let q = q % cols as i32;
            let r = r % rows as i32;
            let s = Hex2D::new(rows, cols).unwrap();
            let coord: Coord = smallvec![q, r];
            for nb in s.neighbours(&coord) {
                let nb_neighbours = s.neighbours(&nb);
                prop_assert!(
                    nb_neighbours.contains(&coord),
                    "neighbour symmetry violated: {:?} in N({:?}) but {:?} not in N({:?})",
                    nb, coord, coord, nb,
                );
            }
        }
    }
}