lattice_graph/square/
mod.rs

//! Square 2d Lattice Graph. It does not use [`lattice_abstract`](`crate::lattice_abstract`) for historical and performance reasons.
2
3use crate::unreachable_debug_checked;
4use ndarray::Array2;
5
6use fixedbitset::FixedBitSet;
7use petgraph::{
8    data::{DataMap, DataMapMut},
9    graph::IndexType,
10    visit::{
11        Data, GraphBase, GraphProp, IntoNodeIdentifiers, IntoNodeReferences, NodeIndexable,
12        VisitMap, Visitable,
13    },
14    Undirected,
15};
16use std::{iter::FusedIterator, marker::PhantomData, ops::Range};
17
18mod edges;
19pub use edges::*;
20mod index;
21pub use index::*;
22mod neighbors;
23pub use neighbors::*;
24mod nodes;
25pub use nodes::*;
26
27#[cfg(test)]
28mod tests;
29
30/// Shape of the [`SquareGraph`]. It tells that the graph loops or not.
31pub trait Shape: Copy {
32    /// SizeInfo is needed if loop is enabled.
33    type SizeInfo: SizeInfo;
34    /// Whether the graph loops in horizontal axis.
35    const LOOP_HORIZONTAL: bool = false;
36    /// Whether the graph loops in vertical axis.
37    const LOOP_VERTICAL: bool = false;
38    /// Get a size info used in [`EdgeReference`].
39    fn get_sizeinfo(h: usize, v: usize) -> Self::SizeInfo;
40}
41
42/// It holds a infomation of size of graph if needed.
43/// This is used in [`EdgeReference`] to tell the loop info.
44/// This trick is to optimize when there is no loop in graph.
45pub trait SizeInfo: Copy {
46    /// Should only be called when [`Shape::LOOP_HORIZONTAL`] is true.
47    unsafe fn horizontal_size(&self) -> usize {
48        unreachable_debug_checked()
49    }
50    /// Should only be called when [`Shape::LOOP_VERTICAL`] is true.
51    unsafe fn vertical_size(&self) -> usize {
52        unreachable_debug_checked()
53    }
54}
55
56impl SizeInfo for () {}
57impl SizeInfo for (usize, ()) {
58    unsafe fn horizontal_size(&self) -> usize {
59        self.0
60    }
61}
62impl SizeInfo for ((), usize) {
63    unsafe fn vertical_size(&self) -> usize {
64        self.1
65    }
66}
67impl SizeInfo for (usize, usize) {
68    unsafe fn horizontal_size(&self) -> usize {
69        self.0
70    }
71    unsafe fn vertical_size(&self) -> usize {
72        self.1
73    }
74}
75
76/// Marker that the graph does not loop.
77#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
78pub enum DefaultShape {}
79impl Shape for DefaultShape {
80    type SizeInfo = ();
81    #[inline]
82    fn get_sizeinfo(_h: usize, _v: usize) -> Self::SizeInfo {}
83}
84/// Marker that the graph does loops in horizontal axis.
85#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
86pub enum HorizontalLoop {}
87impl Shape for HorizontalLoop {
88    type SizeInfo = (usize, ());
89    const LOOP_HORIZONTAL: bool = true;
90    #[inline]
91    fn get_sizeinfo(h: usize, _v: usize) -> Self::SizeInfo {
92        (h, ())
93    }
94}
95/// Marker that the graph does loops in vertical axis.
96#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
97pub enum VerticalLoop {}
98impl Shape for VerticalLoop {
99    type SizeInfo = ((), usize);
100    const LOOP_VERTICAL: bool = true;
101    #[inline]
102    fn get_sizeinfo(_h: usize, v: usize) -> Self::SizeInfo {
103        ((), v)
104    }
105}
106/// Marker that the graph does loops in horizontal and vertical axis.
107#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
108pub enum HVLoop {}
109impl Shape for HVLoop {
110    type SizeInfo = (usize, usize);
111    const LOOP_VERTICAL: bool = true;
112    const LOOP_HORIZONTAL: bool = true;
113    #[inline]
114    fn get_sizeinfo(h: usize, v: usize) -> Self::SizeInfo {
115        (h, v)
116    }
117}
118
119/// Undirected Square Grid Graph. It is has rectangle shape.
120/// ```text
121/// Node(i,j+1) - Edge(i,j+1,Horizontal) - Node(i+1,j+1)
122///   |                                     |
123/// Edge(i,j,Vertical)                     Edge(i+1,j,Vertical)
124///   |                                     |
125/// Node(i,j)   - Edge(i,j,Horizontal)   - Node(i+1,j)
126/// ```
127#[derive(Clone, Debug)]
128pub struct SquareGraph<N, E, Ix = usize, S = DefaultShape>
129where
130    Ix: IndexType,
131{
132    /// `[horizontal][vertical]`
133    nodes: Array2<N>,
134    horizontal: Array2<E>, //→
135    vertical: Array2<E>,   //↑
136    s: PhantomData<fn() -> S>,
137    pd: PhantomData<fn() -> Ix>,
138}
139
140impl<N, E, Ix, S> SquareGraph<N, E, Ix, S>
141where
142    Ix: IndexType,
143    S: Shape,
144{
145    /// Create a `SquareGraph` from raw data.
146    /// It only check whether the size of nodes and edges are correct in `debug_assertion`.
147    pub unsafe fn new_raw(nodes: Array2<N>, horizontal: Array2<E>, vertical: Array2<E>) -> Self {
148        let s = Self {
149            nodes,
150            horizontal,
151            vertical,
152            s: PhantomData,
153            pd: PhantomData,
154        };
155        debug_assert!(s.check_gen());
156        s
157    }
158
159    /// Create a `SquareGraph` with the nodes and edges initialized with default.
160    pub fn new(h: usize, v: usize) -> Self
161    where
162        N: Default,
163        E: Default,
164    {
165        Self::new_with(h, v, |_, _| N::default(), |_, _, _| E::default())
166    }
167
168    /// Creates a `SquareGraph` with initializing nodes and edges from position.
169    pub fn new_with<FN, FE>(h: usize, v: usize, mut fnode: FN, mut fedge: FE) -> Self
170    where
171        FN: FnMut(usize, usize) -> N,
172        FE: FnMut(usize, usize, Axis) -> E,
173    {
174        assert!(h > 0, "h must be non zero");
175
176        // Initialize nodes array
177        let mut nodes_vec = Vec::with_capacity(h * v);
178        for hi in 0..h {
179            for vi in 0..v {
180                nodes_vec.push(fnode(hi, vi));
181            }
182        }
183        let nodes = Array2::from_shape_vec((h, v), nodes_vec).expect("Array2 creation failed");
184
185        // Initialize horizontal edges
186        let mh = if <S as Shape>::LOOP_HORIZONTAL {
187            h
188        } else {
189            h - 1
190        };
191        let mut horizontal_vec = Vec::with_capacity(mh * v);
192        for hi in 0..mh {
193            for vi in 0..v {
194                horizontal_vec.push(fedge(hi, vi, Axis::Horizontal));
195            }
196        }
197        let horizontal =
198            Array2::from_shape_vec((mh, v), horizontal_vec).expect("Array2 creation failed");
199
200        // Initialize vertical edges
201        let mv = if <S as Shape>::LOOP_VERTICAL {
202            v
203        } else {
204            v - 1
205        };
206        let mut vertical_vec = Vec::with_capacity(h * mv);
207        for hi in 0..h {
208            for vi in 0..mv {
209                vertical_vec.push(fedge(hi, vi, Axis::Vertical));
210            }
211        }
212        let vertical =
213            Array2::from_shape_vec((h, mv), vertical_vec).expect("Array2 creation failed");
214
215        unsafe { Self::new_raw(nodes, horizontal, vertical) }
216    }
217
218    /// Check the size of nodes and edges.
219    fn check_gen(&self) -> bool {
220        self.nodes.nrows()
221            == self.horizontal.nrows() + if <S as Shape>::LOOP_HORIZONTAL { 0 } else { 1 }
222            && self.nodes.ncols() == self.horizontal.ncols()
223            && self.nodes.nrows() == self.vertical.nrows()
224            && self.nodes.ncols()
225                == self.vertical.ncols() + if <S as Shape>::LOOP_VERTICAL { 0 } else { 1 }
226    }
227
228    #[inline]
229    /// Get the edge from node.
230    pub fn get_edge_id(
231        &self,
232        node: NodeIndex<Ix>,
233        dir: SquareDirection,
234    ) -> Option<(EdgeIndex<Ix>, bool)> {
235        let x = match dir {
236            SquareDirection::Foward(a @ Axis::Vertical)
237                if node.vertical.index() + 1 < self.vertical_node_count() =>
238            {
239                (node, a, true)
240            }
241            SquareDirection::Foward(a @ Axis::Vertical)
242                if <S as Shape>::LOOP_VERTICAL
243                    && node.vertical.index() + 1 == self.vertical_node_count() =>
244            {
245                (
246                    NodeIndex {
247                        horizontal: node.horizontal,
248                        vertical: Ix::default(),
249                    },
250                    a,
251                    true,
252                )
253            }
254            SquareDirection::Foward(a @ Axis::Horizontal)
255                if node.horizontal.index() + 1 < self.horizontal_node_count() =>
256            {
257                (node, a, true)
258            }
259            SquareDirection::Foward(a @ Axis::Horizontal)
260                if <S as Shape>::LOOP_HORIZONTAL
261                    && node.horizontal.index() + 1 == self.horizontal_node_count() =>
262            {
263                (
264                    NodeIndex {
265                        horizontal: Ix::default(),
266                        vertical: node.vertical,
267                    },
268                    a,
269                    true,
270                )
271            }
272            SquareDirection::Backward(a @ Axis::Vertical) if node.vertical.index() != 0 => {
273                (node.down(), a, false)
274            }
275            SquareDirection::Backward(a @ Axis::Vertical)
276                if <S as Shape>::LOOP_VERTICAL && node.vertical.index() == 0 =>
277            {
278                (
279                    NodeIndex {
280                        horizontal: node.horizontal,
281                        vertical: Ix::new(self.vertical_node_count() - 1),
282                    },
283                    a,
284                    false,
285                )
286            }
287            SquareDirection::Backward(a @ Axis::Horizontal) if node.horizontal.index() != 0 => {
288                (node.left(), a, false)
289            }
290            SquareDirection::Backward(a @ Axis::Horizontal)
291                if <S as Shape>::LOOP_HORIZONTAL && node.horizontal.index() == 0 =>
292            {
293                (
294                    NodeIndex {
295                        horizontal: Ix::new(self.horizontal_node_count() - 1),
296                        vertical: node.vertical,
297                    },
298                    a,
299                    false,
300                )
301            }
302            _ => return None,
303        };
304        Some(((x.0, x.1).into(), x.2))
305    }
306
307    #[inline]
308    /// Get the edge reference form node.
309    pub fn get_edge_reference(
310        &self,
311        n: NodeIndex<Ix>,
312        dir: SquareDirection,
313    ) -> Option<EdgeReference<'_, E, Ix, S>> {
314        self.get_edge_id(n, dir).map(|(e, fo)| EdgeReference {
315            edge_id: e,
316            edge_weight: unsafe {
317                if dir.is_horizontal() {
318                    self.horizontal
319                        .uget((e.node.horizontal.index(), e.node.vertical.index()))
320                } else {
321                    self.vertical
322                        .uget((e.node.horizontal.index(), e.node.vertical.index()))
323                }
324            },
325            direction: fo,
326            s: S::get_sizeinfo(self.horizontal_node_count(), self.vertical_node_count()),
327            spd: PhantomData,
328        })
329    }
330}
331
332impl<N, E, Ix, S> SquareGraph<N, E, Ix, S>
333where
334    Ix: IndexType,
335{
336    /// Returns the Node count in the horizontal direction.
337    pub fn horizontal_node_count(&self) -> usize {
338        self.nodes.nrows()
339    }
340
341    /// Returns the Node count in the vertical direction.
342    pub fn vertical_node_count(&self) -> usize {
343        self.nodes.ncols()
344    }
345
346    /// Get a reference to the nodes. `[horizontal][vertical]`
347    pub fn nodes(&self) -> &Array2<N> {
348        &self.nodes
349    }
350
351    /// Get a reference to the horizontal edges. `[horizontal][vertical]`
352    pub fn horizontal(&self) -> &Array2<E> {
353        &self.horizontal
354    }
355
356    /// Get a reference to the vertical edges. `[horizontal][vertical]`
357    pub fn vertical(&self) -> &Array2<E> {
358        &self.vertical
359    }
360
361    /// Get a mutable reference to the nodes. `[horizontal][vertical]`
362    pub fn nodes_mut(&mut self) -> &mut Array2<N> {
363        &mut self.nodes
364    }
365
366    /// Get a mutable reference to the horizontal edges. `[horizontal][vertical]`
367    pub fn horizontal_mut(&mut self) -> &mut Array2<E> {
368        &mut self.horizontal
369    }
370
371    /// Get a mutable reference to the vertical edges.
372    pub fn vertical_mut(&mut self) -> &mut Array2<E> {
373        &mut self.vertical
374    }
375}
376
377impl<E, Ix, S> SquareGraph<(), E, Ix, S>
378where
379    Ix: IndexType,
380    S: Shape,
381{
382    /// Create a `SquareGraph` with the edges initialized from position.
383    pub fn new_edge_graph<FE>(h: usize, v: usize, fedge: FE) -> Self
384    where
385        FE: FnMut(usize, usize, Axis) -> E,
386    {
387        Self::new_with(h, v, |_, _| (), fedge)
388    }
389}
390
391impl<N, E, Ix, S> GraphBase for SquareGraph<N, E, Ix, S>
392where
393    Ix: IndexType,
394{
395    type NodeId = NodeIndex<Ix>;
396    type EdgeId = EdgeIndex<Ix>;
397}
398
399impl<N, E, Ix, S> Data for SquareGraph<N, E, Ix, S>
400where
401    Ix: IndexType,
402{
403    type NodeWeight = N;
404    type EdgeWeight = E;
405}
406
407impl<N, E, Ix, S> DataMap for SquareGraph<N, E, Ix, S>
408where
409    Ix: IndexType,
410{
411    fn node_weight(&self, id: Self::NodeId) -> Option<&Self::NodeWeight> {
412        self.nodes.get((id.horizontal.index(), id.vertical.index()))
413    }
414
415    fn edge_weight(&self, id: Self::EdgeId) -> Option<&Self::EdgeWeight> {
416        match id.axis {
417            Axis::Horizontal => &self.horizontal,
418            Axis::Vertical => &self.vertical,
419        }
420        .get((id.node.horizontal.index(), id.node.vertical.index()))
421    }
422}
423
424impl<N, E, Ix, S> DataMapMut for SquareGraph<N, E, Ix, S>
425where
426    Ix: IndexType,
427{
428    fn node_weight_mut(&mut self, id: Self::NodeId) -> Option<&mut Self::NodeWeight> {
429        self.nodes
430            .get_mut((id.horizontal.index(), id.vertical.index()))
431    }
432
433    fn edge_weight_mut(&mut self, id: Self::EdgeId) -> Option<&mut Self::EdgeWeight> {
434        match id.axis {
435            Axis::Horizontal => &mut self.horizontal,
436            Axis::Vertical => &mut self.vertical,
437        }
438        .get_mut((id.node.horizontal.index(), id.node.vertical.index()))
439    }
440}
441
442impl<N, E, Ix, S> GraphProp for SquareGraph<N, E, Ix, S>
443where
444    Ix: IndexType,
445{
446    type EdgeType = Undirected;
447}
448
449/// [`VisitMap`] of [`SquareGraph`].
450#[derive(Debug, Clone, Hash, PartialEq, Eq)]
451pub struct VisMap {
452    v: Vec<FixedBitSet>,
453}
454
455impl VisMap {
456    pub fn new(h: usize, v: usize) -> Self {
457        let mut vec = Vec::with_capacity(h);
458        for _ in 0..h {
459            vec.push(FixedBitSet::with_capacity(v));
460        }
461        Self { v: vec }
462    }
463}
464
465impl<Ix: IndexType> VisitMap<NodeIndex<Ix>> for VisMap {
466    fn visit(&mut self, a: NodeIndex<Ix>) -> bool {
467        !self.v[a.horizontal.index()].put(a.vertical.index())
468    }
469
470    fn is_visited(&self, a: &NodeIndex<Ix>) -> bool {
471        self.v[a.horizontal.index()].contains(a.vertical.index())
472    }
473
474    fn unvisit(&mut self, a: NodeIndex<Ix>) -> bool {
475        self.v[a.horizontal.index()].set(a.vertical.index(), false);
476        true
477    }
478}
479
480impl<N, E, Ix, S> Visitable for SquareGraph<N, E, Ix, S>
481where
482    Ix: IndexType,
483{
484    type Map = VisMap;
485
486    fn visit_map(&self) -> Self::Map {
487        VisMap::new(self.horizontal_node_count(), self.vertical_node_count())
488    }
489
490    fn reset_map(&self, map: &mut Self::Map) {
491        map.v.iter_mut().for_each(|x| x.clear())
492    }
493}