// lattice_graph/square/edges.rs

1use super::*;
2use ndarray::Array2;
3use petgraph::{
4    graph::IndexType,
5    visit::{EdgeRef, IntoEdgeReferences, IntoEdges},
6};
7use std::iter::FusedIterator;
8
9impl<'a, N, E, Ix, S> IntoEdgeReferences for &'a SquareGraph<N, E, Ix, S>
10where
11    Ix: IndexType,
12    E: Copy,
13    S: Shape,
14{
15    type EdgeRef = EdgeReference<'a, E, Ix, S>;
16    type EdgeReferences = EdgeReferences<'a, E, Ix, S>;
17
18    fn edge_references(self) -> Self::EdgeReferences {
19        EdgeReferences::new(self)
20    }
21}
22
/// Reference of Edge data (EdgeIndex, EdgeWeight, direction) in [`SquareGraph`].
///
/// The derived comparison and hashing traits work field by field in declaration
/// order; the weight is compared and hashed by value (through the reference).
#[derive(Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct EdgeReference<'a, E, Ix: IndexType, S: Shape> {
    /// Index of the edge: the node it is stored at plus its axis.
    pub(crate) edge_id: EdgeIndex<Ix>,
    /// Weight of the edge, borrowed from the graph.
    pub(crate) edge_weight: &'a E,
    /// `true` when the stored node `edge_id.node` is the source of the edge,
    /// `false` when it is the target.
    pub(crate) direction: bool,
    /// Size information of the shape, used to resolve the wrap-around endpoint
    /// of edges on looping axes.
    pub(crate) s: S::SizeInfo,
    /// Marker tying the reference to the shape type `S` without storing a value of it.
    pub(crate) spd: PhantomData<fn() -> S>,
}

33impl<'a, E, Ix: IndexType, S: Shape> Clone for EdgeReference<'a, E, Ix, S> {
34    fn clone(&self) -> Self {
35        *self
36    }
37}
38
39impl<'a, E, Ix: IndexType, S: Shape> Copy for EdgeReference<'a, E, Ix, S> {}
40
impl<'a, E, Ix: IndexType, S: Shape> EdgeReference<'a, E, Ix, S> {
    /// Returns one endpoint of this edge.
    ///
    /// With `is_source == true` this is the stored node `edge_id.node`. Otherwise
    /// it is the neighbour across the edge: `node.right()` for a horizontal edge and
    /// `node.up()` for a vertical one. On an axis where the shape loops and the
    /// stored node has the last index, the neighbour wraps back to index 0.
    ///
    /// Note: `is_source` is relative to the stored node, not to the edge direction.
    /// `source`/`target` pass `self.direction` and `!self.direction` respectively.
    #[inline]
    fn get_node(&self, is_source: bool) -> NodeIndex<Ix> {
        let node = self.edge_id.node;
        if is_source {
            node
        } else {
            match self.edge_id.axis {
                Axis::Horizontal => {
                    // The `unsafe` size query is only evaluated when `LOOP_HORIZONTAL`
                    // holds, because `&&` short-circuits. NOTE(review): the precondition
                    // of `horizontal_size` is not visible here; presumably it is only
                    // valid for looping shapes. Confirm against `Shape::SizeInfo`.
                    if S::LOOP_HORIZONTAL
                        && node.horizontal.index() + 1 == unsafe { self.s.horizontal_size() }
                    {
                        // Last horizontal index of a looping shape: wrap to index 0.
                        NodeIndex::new(Ix::new(0), node.vertical)
                    } else {
                        node.right()
                    }
                }
                Axis::Vertical => {
                    // Same short-circuit guard as above, for the vertical axis.
                    if S::LOOP_VERTICAL
                        && node.vertical.index() + 1 == unsafe { self.s.vertical_size() }
                    {
                        // Last vertical index of a looping shape: wrap to index 0.
                        NodeIndex::new(node.horizontal, Ix::new(0))
                    } else {
                        node.up()
                    }
                }
            }
        }
    }
}

72impl<'a, E: Copy, Ix: IndexType, S: Shape> EdgeRef for EdgeReference<'a, E, Ix, S> {
73    type NodeId = NodeIndex<Ix>;
74    type EdgeId = EdgeIndex<Ix>;
75    type Weight = E;
76
77    #[inline]
78    fn source(&self) -> Self::NodeId {
79        self.get_node(self.direction)
80    }
81
82    #[inline]
83    fn target(&self) -> Self::NodeId {
84        self.get_node(!self.direction)
85    }
86
87    fn weight(&self) -> &Self::Weight {
88        self.edge_weight
89    }
90
91    fn id(&self) -> Self::EdgeId {
92        self.edge_id
93    }
94}
95
96/// Iterator for all edges of [`SquareGraph`]. See [`IntoEdgeReferences`](`IntoEdgeReferences::edge_references`).
97#[derive(Clone, Debug)]
98pub struct EdgeReferences<'a, E, Ix: IndexType, S> {
99    horizontal: &'a Array2<E>,
100    vertical: &'a Array2<E>,
101    nodes: NodeIndices<Ix>,
102    prv: Option<NodeIndex<Ix>>,
103    s: PhantomData<S>,
104}
105
106impl<'a, E, Ix: IndexType, S> EdgeReferences<'a, E, Ix, S> {
107    fn new<N>(graph: &'a SquareGraph<N, E, Ix, S>) -> Self {
108        Self {
109            horizontal: &graph.horizontal,
110            vertical: &graph.vertical,
111            nodes: NodeIndices::new(graph.horizontal_node_count(), graph.vertical_node_count()),
112            prv: None,
113            s: PhantomData,
114        }
115    }
116}
117
118impl<'a, E, Ix, S> Iterator for EdgeReferences<'a, E, Ix, S>
119where
120    Ix: IndexType,
121    S: Shape,
122{
123    type Item = EdgeReference<'a, E, Ix, S>;
124
125    fn next(&mut self) -> Option<Self::Item> {
126        let s = S::get_sizeinfo(self.nodes.h_max, self.nodes.v_max);
127        loop {
128            match self.prv {
129                None => {
130                    let x = self.nodes.next()?;
131                    let e = EdgeIndex {
132                        node: x,
133                        axis: Axis::Horizontal,
134                    };
135                    self.prv = Some(x);
136                    let ew = self
137                        .horizontal
138                        .get((x.horizontal.index(), x.vertical.index()));
139                    if let Some(ew) = ew {
140                        return Some(EdgeReference {
141                            edge_id: e,
142                            edge_weight: ew,
143                            direction: true,
144                            s,
145                            spd: PhantomData,
146                        });
147                    }
148                }
149                Some(x) => {
150                    self.prv = None;
151                    let ew = self
152                        .vertical
153                        .get((x.horizontal.index(), x.vertical.index()));
154                    if let Some(ew) = ew {
155                        return Some(EdgeReference {
156                            edge_id: EdgeIndex {
157                                node: x,
158                                axis: Axis::Vertical,
159                            },
160                            edge_weight: ew,
161                            direction: true,
162                            s,
163                            spd: PhantomData,
164                        });
165                    }
166                }
167            }
168        }
169    }
170}
171
172/// Edges connected to a node. See [`edges`][`IntoEdges::edges`].
173#[derive(Clone, Debug)]
174pub struct Edges<'a, N, E, Ix: IndexType, S> {
175    g: &'a SquareGraph<N, E, Ix, S>,
176    node: NodeIndex<Ix>,
177    state: usize,
178}
179
impl<'a, N, E, Ix, S> Iterator for Edges<'a, N, E, Ix, S>
where
    Ix: IndexType,
    S: Shape,
{
    type Item = EdgeReference<'a, E, Ix, S>;

    /// Yields the next edge incident to `self.node`, probing the neighbours in the
    /// order left (state 0), right (1), down (2), up (3).
    ///
    /// Every yielded edge has `self.node` as its source. On a non-looping axis the
    /// neighbour beyond the boundary does not exist and that direction is skipped.
    /// On a looping axis the neighbour wraps to the opposite side.
    fn next(&mut self) -> Option<Self::Item> {
        let g = self.g;
        let n = self.node;
        let s = S::get_sizeinfo(g.horizontal_node_count(), g.vertical_node_count());
        loop {
            // `break 'inner` means "no edge in this direction": control falls through
            // to the `self.state += 1` below and the next direction is tried.
            'inner: {
                let er = match self.state {
                    // Left: the edge is stored at the left neighbour `new_n`. With
                    // `direction: false` the stored node is the target, so `n` is the source.
                    0 => {
                        let new_n = if n.horizontal.index() == 0 {
                            if !S::LOOP_HORIZONTAL {
                                break 'inner;
                            }
                            NodeIndex::new(Ix::new(g.horizontal_node_count() - 1), n.vertical)
                        } else {
                            n.left()
                        };
                        EdgeReference {
                            edge_id: EdgeIndex {
                                node: new_n,
                                axis: Axis::Horizontal,
                            },
                            // SAFETY (NOTE(review)): unchecked read. Relies on `n` being a node
                            // of `g` and on `g.horizontal` having an entry at every node with a
                            // horizontal edge; neither is checked in this function.
                            edge_weight: unsafe {
                                g.horizontal
                                    .uget((new_n.horizontal.index(), new_n.vertical.index()))
                            },
                            direction: false,
                            s,
                            spd: PhantomData,
                        }
                    }
                    // Right: the edge is stored at `n` itself, which is its source.
                    1 => {
                        debug_assert!(n.horizontal.index() < g.horizontal_node_count());
                        if !S::LOOP_HORIZONTAL
                            && n.horizontal.index() + 1 == g.horizontal_node_count()
                        {
                            break 'inner;
                        }
                        EdgeReference {
                            edge_id: EdgeIndex {
                                node: n,
                                axis: Axis::Horizontal,
                            },
                            // SAFETY (NOTE(review)): same unchecked-index assumption as state 0.
                            edge_weight: unsafe {
                                g.horizontal
                                    .uget((n.horizontal.index(), n.vertical.index()))
                            },
                            direction: true,
                            s,
                            spd: PhantomData,
                        }
                    }
                    // Down: mirror of state 0 on the vertical axis.
                    2 => {
                        let new_n = if n.vertical.index() == 0 {
                            if !S::LOOP_VERTICAL {
                                break 'inner;
                            }
                            NodeIndex::new(n.horizontal, Ix::new(g.vertical_node_count() - 1))
                        } else {
                            n.down()
                        };
                        EdgeReference {
                            edge_id: EdgeIndex {
                                node: new_n,
                                axis: Axis::Vertical,
                            },
                            // SAFETY (NOTE(review)): as in state 0, but for `g.vertical`.
                            edge_weight: unsafe {
                                g.vertical
                                    .uget((new_n.horizontal.index(), new_n.vertical.index()))
                            },
                            direction: false,
                            s,
                            spd: PhantomData,
                        }
                    }
                    // Up: mirror of state 1 on the vertical axis.
                    3 => {
                        debug_assert!(n.vertical.index() < g.vertical_node_count());
                        if !S::LOOP_VERTICAL && n.vertical.index() + 1 == g.vertical_node_count() {
                            break 'inner;
                        }
                        EdgeReference {
                            edge_id: EdgeIndex {
                                node: n,
                                axis: Axis::Vertical,
                            },
                            // SAFETY (NOTE(review)): as in state 1, but for `g.vertical`.
                            edge_weight: unsafe {
                                g.vertical.uget((n.horizontal.index(), n.vertical.index()))
                            },
                            direction: true,
                            s,
                            spd: PhantomData,
                        }
                    }
                    // All four directions done. `state` is not advanced here, so it never
                    // exceeds 4 and the iterator keeps returning `None`.
                    _ => return None,
                };
                self.state += 1;
                return Some(er);
            }
            self.state += 1;
        }
    }

    /// At most one edge per remaining direction. `state <= 4` always holds, so the
    /// subtraction cannot underflow.
    fn size_hint(&self) -> (usize, Option<usize>) {
        (0, Some(4 - self.state))
    }
}

293impl<'a, N, E, Ix, S> FusedIterator for Edges<'a, N, E, Ix, S>
294where
295    Ix: IndexType,
296    S: Shape,
297{
298}
299
300impl<'a, N, E, Ix, S> IntoEdges for &'a SquareGraph<N, E, Ix, S>
301where
302    Ix: IndexType,
303    E: Copy,
304    S: Shape,
305{
306    type Edges = Edges<'a, N, E, Ix, S>;
307
308    fn edges(self, a: Self::NodeId) -> Self::Edges {
309        Edges {
310            g: self,
311            node: a,
312            state: 0,
313        }
314    }
315}
316
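// NOTE(review): disabled. `Edges` always yields the queried node as the source
// endpoint, so reusing it for `IntoEdgesDirected` would presumably be wrong for
// incoming edges; confirm before re-enabling.
// impl<'a, N, E, Ix, S> IntoEdgesDirected for &'a SquareGraph<N, E, Ix, S>
// where
//     Ix: IndexType,
//     E: Copy,
//     S: Shape,
// {
//     type EdgesDirected = Edges<'a, N, E, Ix, S>;
//
//     fn edges_directed(self, a: Self::NodeId, _dir: petgraph::EdgeDirection) -> Self::EdgesDirected {
//         self.edges(a)
//     }
// }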