1use crate::unreachable_debug_checked;
4use ndarray::Array2;
5
6use fixedbitset::FixedBitSet;
7use petgraph::{
8 data::{DataMap, DataMapMut},
9 graph::IndexType,
10 visit::{
11 Data, GraphBase, GraphProp, IntoNodeIdentifiers, IntoNodeReferences, NodeIndexable,
12 VisitMap, Visitable,
13 },
14 Undirected,
15};
16use std::{iter::FusedIterator, marker::PhantomData, ops::Range};
17
18mod edges;
19pub use edges::*;
20mod index;
21pub use index::*;
22mod neighbors;
23pub use neighbors::*;
24mod nodes;
25pub use nodes::*;
26
27#[cfg(test)]
28mod tests;
29
30pub trait Shape: Copy {
32 type SizeInfo: SizeInfo;
34 const LOOP_HORIZONTAL: bool = false;
36 const LOOP_VERTICAL: bool = false;
38 fn get_sizeinfo(h: usize, v: usize) -> Self::SizeInfo;
40}
41
42pub trait SizeInfo: Copy {
46 unsafe fn horizontal_size(&self) -> usize {
48 unreachable_debug_checked()
49 }
50 unsafe fn vertical_size(&self) -> usize {
52 unreachable_debug_checked()
53 }
54}
55
56impl SizeInfo for () {}
57impl SizeInfo for (usize, ()) {
58 unsafe fn horizontal_size(&self) -> usize {
59 self.0
60 }
61}
62impl SizeInfo for ((), usize) {
63 unsafe fn vertical_size(&self) -> usize {
64 self.1
65 }
66}
67impl SizeInfo for (usize, usize) {
68 unsafe fn horizontal_size(&self) -> usize {
69 self.0
70 }
71 unsafe fn vertical_size(&self) -> usize {
72 self.1
73 }
74}
75
76#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
78pub enum DefaultShape {}
79impl Shape for DefaultShape {
80 type SizeInfo = ();
81 #[inline]
82 fn get_sizeinfo(_h: usize, _v: usize) -> Self::SizeInfo {}
83}
84#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
86pub enum HorizontalLoop {}
87impl Shape for HorizontalLoop {
88 type SizeInfo = (usize, ());
89 const LOOP_HORIZONTAL: bool = true;
90 #[inline]
91 fn get_sizeinfo(h: usize, _v: usize) -> Self::SizeInfo {
92 (h, ())
93 }
94}
95#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
97pub enum VerticalLoop {}
98impl Shape for VerticalLoop {
99 type SizeInfo = ((), usize);
100 const LOOP_VERTICAL: bool = true;
101 #[inline]
102 fn get_sizeinfo(_h: usize, v: usize) -> Self::SizeInfo {
103 ((), v)
104 }
105}
106#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
108pub enum HVLoop {}
109impl Shape for HVLoop {
110 type SizeInfo = (usize, usize);
111 const LOOP_VERTICAL: bool = true;
112 const LOOP_HORIZONTAL: bool = true;
113 #[inline]
114 fn get_sizeinfo(h: usize, v: usize) -> Self::SizeInfo {
115 (h, v)
116 }
117}
118
119#[derive(Clone, Debug)]
128pub struct SquareGraph<N, E, Ix = usize, S = DefaultShape>
129where
130 Ix: IndexType,
131{
132 nodes: Array2<N>,
134 horizontal: Array2<E>, vertical: Array2<E>, s: PhantomData<fn() -> S>,
137 pd: PhantomData<fn() -> Ix>,
138}
139
140impl<N, E, Ix, S> SquareGraph<N, E, Ix, S>
141where
142 Ix: IndexType,
143 S: Shape,
144{
145 pub unsafe fn new_raw(nodes: Array2<N>, horizontal: Array2<E>, vertical: Array2<E>) -> Self {
148 let s = Self {
149 nodes,
150 horizontal,
151 vertical,
152 s: PhantomData,
153 pd: PhantomData,
154 };
155 debug_assert!(s.check_gen());
156 s
157 }
158
159 pub fn new(h: usize, v: usize) -> Self
161 where
162 N: Default,
163 E: Default,
164 {
165 Self::new_with(h, v, |_, _| N::default(), |_, _, _| E::default())
166 }
167
168 pub fn new_with<FN, FE>(h: usize, v: usize, mut fnode: FN, mut fedge: FE) -> Self
170 where
171 FN: FnMut(usize, usize) -> N,
172 FE: FnMut(usize, usize, Axis) -> E,
173 {
174 assert!(h > 0, "h must be non zero");
175
176 let mut nodes_vec = Vec::with_capacity(h * v);
178 for hi in 0..h {
179 for vi in 0..v {
180 nodes_vec.push(fnode(hi, vi));
181 }
182 }
183 let nodes = Array2::from_shape_vec((h, v), nodes_vec).expect("Array2 creation failed");
184
185 let mh = if <S as Shape>::LOOP_HORIZONTAL {
187 h
188 } else {
189 h - 1
190 };
191 let mut horizontal_vec = Vec::with_capacity(mh * v);
192 for hi in 0..mh {
193 for vi in 0..v {
194 horizontal_vec.push(fedge(hi, vi, Axis::Horizontal));
195 }
196 }
197 let horizontal =
198 Array2::from_shape_vec((mh, v), horizontal_vec).expect("Array2 creation failed");
199
200 let mv = if <S as Shape>::LOOP_VERTICAL {
202 v
203 } else {
204 v - 1
205 };
206 let mut vertical_vec = Vec::with_capacity(h * mv);
207 for hi in 0..h {
208 for vi in 0..mv {
209 vertical_vec.push(fedge(hi, vi, Axis::Vertical));
210 }
211 }
212 let vertical =
213 Array2::from_shape_vec((h, mv), vertical_vec).expect("Array2 creation failed");
214
215 unsafe { Self::new_raw(nodes, horizontal, vertical) }
216 }
217
218 fn check_gen(&self) -> bool {
220 self.nodes.nrows()
221 == self.horizontal.nrows() + if <S as Shape>::LOOP_HORIZONTAL { 0 } else { 1 }
222 && self.nodes.ncols() == self.horizontal.ncols()
223 && self.nodes.nrows() == self.vertical.nrows()
224 && self.nodes.ncols()
225 == self.vertical.ncols() + if <S as Shape>::LOOP_VERTICAL { 0 } else { 1 }
226 }
227
228 #[inline]
229 pub fn get_edge_id(
231 &self,
232 node: NodeIndex<Ix>,
233 dir: SquareDirection,
234 ) -> Option<(EdgeIndex<Ix>, bool)> {
235 let x = match dir {
236 SquareDirection::Foward(a @ Axis::Vertical)
237 if node.vertical.index() + 1 < self.vertical_node_count() =>
238 {
239 (node, a, true)
240 }
241 SquareDirection::Foward(a @ Axis::Vertical)
242 if <S as Shape>::LOOP_VERTICAL
243 && node.vertical.index() + 1 == self.vertical_node_count() =>
244 {
245 (
246 NodeIndex {
247 horizontal: node.horizontal,
248 vertical: Ix::default(),
249 },
250 a,
251 true,
252 )
253 }
254 SquareDirection::Foward(a @ Axis::Horizontal)
255 if node.horizontal.index() + 1 < self.horizontal_node_count() =>
256 {
257 (node, a, true)
258 }
259 SquareDirection::Foward(a @ Axis::Horizontal)
260 if <S as Shape>::LOOP_HORIZONTAL
261 && node.horizontal.index() + 1 == self.horizontal_node_count() =>
262 {
263 (
264 NodeIndex {
265 horizontal: Ix::default(),
266 vertical: node.vertical,
267 },
268 a,
269 true,
270 )
271 }
272 SquareDirection::Backward(a @ Axis::Vertical) if node.vertical.index() != 0 => {
273 (node.down(), a, false)
274 }
275 SquareDirection::Backward(a @ Axis::Vertical)
276 if <S as Shape>::LOOP_VERTICAL && node.vertical.index() == 0 =>
277 {
278 (
279 NodeIndex {
280 horizontal: node.horizontal,
281 vertical: Ix::new(self.vertical_node_count() - 1),
282 },
283 a,
284 false,
285 )
286 }
287 SquareDirection::Backward(a @ Axis::Horizontal) if node.horizontal.index() != 0 => {
288 (node.left(), a, false)
289 }
290 SquareDirection::Backward(a @ Axis::Horizontal)
291 if <S as Shape>::LOOP_HORIZONTAL && node.horizontal.index() == 0 =>
292 {
293 (
294 NodeIndex {
295 horizontal: Ix::new(self.horizontal_node_count() - 1),
296 vertical: node.vertical,
297 },
298 a,
299 false,
300 )
301 }
302 _ => return None,
303 };
304 Some(((x.0, x.1).into(), x.2))
305 }
306
307 #[inline]
308 pub fn get_edge_reference(
310 &self,
311 n: NodeIndex<Ix>,
312 dir: SquareDirection,
313 ) -> Option<EdgeReference<'_, E, Ix, S>> {
314 self.get_edge_id(n, dir).map(|(e, fo)| EdgeReference {
315 edge_id: e,
316 edge_weight: unsafe {
317 if dir.is_horizontal() {
318 self.horizontal
319 .uget((e.node.horizontal.index(), e.node.vertical.index()))
320 } else {
321 self.vertical
322 .uget((e.node.horizontal.index(), e.node.vertical.index()))
323 }
324 },
325 direction: fo,
326 s: S::get_sizeinfo(self.horizontal_node_count(), self.vertical_node_count()),
327 spd: PhantomData,
328 })
329 }
330}
331
332impl<N, E, Ix, S> SquareGraph<N, E, Ix, S>
333where
334 Ix: IndexType,
335{
336 pub fn horizontal_node_count(&self) -> usize {
338 self.nodes.nrows()
339 }
340
341 pub fn vertical_node_count(&self) -> usize {
343 self.nodes.ncols()
344 }
345
346 pub fn nodes(&self) -> &Array2<N> {
348 &self.nodes
349 }
350
351 pub fn horizontal(&self) -> &Array2<E> {
353 &self.horizontal
354 }
355
356 pub fn vertical(&self) -> &Array2<E> {
358 &self.vertical
359 }
360
361 pub fn nodes_mut(&mut self) -> &mut Array2<N> {
363 &mut self.nodes
364 }
365
366 pub fn horizontal_mut(&mut self) -> &mut Array2<E> {
368 &mut self.horizontal
369 }
370
371 pub fn vertical_mut(&mut self) -> &mut Array2<E> {
373 &mut self.vertical
374 }
375}
376
377impl<E, Ix, S> SquareGraph<(), E, Ix, S>
378where
379 Ix: IndexType,
380 S: Shape,
381{
382 pub fn new_edge_graph<FE>(h: usize, v: usize, fedge: FE) -> Self
384 where
385 FE: FnMut(usize, usize, Axis) -> E,
386 {
387 Self::new_with(h, v, |_, _| (), fedge)
388 }
389}
390
391impl<N, E, Ix, S> GraphBase for SquareGraph<N, E, Ix, S>
392where
393 Ix: IndexType,
394{
395 type NodeId = NodeIndex<Ix>;
396 type EdgeId = EdgeIndex<Ix>;
397}
398
399impl<N, E, Ix, S> Data for SquareGraph<N, E, Ix, S>
400where
401 Ix: IndexType,
402{
403 type NodeWeight = N;
404 type EdgeWeight = E;
405}
406
407impl<N, E, Ix, S> DataMap for SquareGraph<N, E, Ix, S>
408where
409 Ix: IndexType,
410{
411 fn node_weight(&self, id: Self::NodeId) -> Option<&Self::NodeWeight> {
412 self.nodes.get((id.horizontal.index(), id.vertical.index()))
413 }
414
415 fn edge_weight(&self, id: Self::EdgeId) -> Option<&Self::EdgeWeight> {
416 match id.axis {
417 Axis::Horizontal => &self.horizontal,
418 Axis::Vertical => &self.vertical,
419 }
420 .get((id.node.horizontal.index(), id.node.vertical.index()))
421 }
422}
423
424impl<N, E, Ix, S> DataMapMut for SquareGraph<N, E, Ix, S>
425where
426 Ix: IndexType,
427{
428 fn node_weight_mut(&mut self, id: Self::NodeId) -> Option<&mut Self::NodeWeight> {
429 self.nodes
430 .get_mut((id.horizontal.index(), id.vertical.index()))
431 }
432
433 fn edge_weight_mut(&mut self, id: Self::EdgeId) -> Option<&mut Self::EdgeWeight> {
434 match id.axis {
435 Axis::Horizontal => &mut self.horizontal,
436 Axis::Vertical => &mut self.vertical,
437 }
438 .get_mut((id.node.horizontal.index(), id.node.vertical.index()))
439 }
440}
441
442impl<N, E, Ix, S> GraphProp for SquareGraph<N, E, Ix, S>
443where
444 Ix: IndexType,
445{
446 type EdgeType = Undirected;
447}
448
449#[derive(Debug, Clone, Hash, PartialEq, Eq)]
451pub struct VisMap {
452 v: Vec<FixedBitSet>,
453}
454
455impl VisMap {
456 pub fn new(h: usize, v: usize) -> Self {
457 let mut vec = Vec::with_capacity(h);
458 for _ in 0..h {
459 vec.push(FixedBitSet::with_capacity(v));
460 }
461 Self { v: vec }
462 }
463}
464
465impl<Ix: IndexType> VisitMap<NodeIndex<Ix>> for VisMap {
466 fn visit(&mut self, a: NodeIndex<Ix>) -> bool {
467 !self.v[a.horizontal.index()].put(a.vertical.index())
468 }
469
470 fn is_visited(&self, a: &NodeIndex<Ix>) -> bool {
471 self.v[a.horizontal.index()].contains(a.vertical.index())
472 }
473
474 fn unvisit(&mut self, a: NodeIndex<Ix>) -> bool {
475 self.v[a.horizontal.index()].set(a.vertical.index(), false);
476 true
477 }
478}
479
480impl<N, E, Ix, S> Visitable for SquareGraph<N, E, Ix, S>
481where
482 Ix: IndexType,
483{
484 type Map = VisMap;
485
486 fn visit_map(&self) -> Self::Map {
487 VisMap::new(self.horizontal_node_count(), self.vertical_node_count())
488 }
489
490 fn reset_map(&self, map: &mut Self::Map) {
491 map.v.iter_mut().for_each(|x| x.clear())
492 }
493}