1use super::*;
2use ndarray::Array2;
3use petgraph::{
4 graph::IndexType,
5 visit::{EdgeRef, IntoEdgeReferences, IntoEdges},
6};
7use std::iter::FusedIterator;
8
9impl<'a, N, E, Ix, S> IntoEdgeReferences for &'a SquareGraph<N, E, Ix, S>
10where
11 Ix: IndexType,
12 E: Copy,
13 S: Shape,
14{
15 type EdgeRef = EdgeReference<'a, E, Ix, S>;
16 type EdgeReferences = EdgeReferences<'a, E, Ix, S>;
17
18 fn edge_references(self) -> Self::EdgeReferences {
19 EdgeReferences::new(self)
20 }
21}
22
/// Reference to a single edge of a `SquareGraph`, yielded by both [`EdgeReferences`] and
/// [`Edges`].
///
/// The derived comparison and hashing impls visit fields in declaration order, and they compare
/// `edge_weight` by value through the reference.
#[derive(Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct EdgeReference<'a, E, Ix: IndexType, S: Shape> {
    /// Stored edge index: the node the edge extends from (towards `right()` for
    /// `Axis::Horizontal`, towards `up()` for `Axis::Vertical`) together with the axis.
    pub(crate) edge_id: EdgeIndex<Ix>,
    /// Borrowed weight of the edge.
    pub(crate) edge_weight: &'a E,
    /// `true`: `edge_id.node` is the source and the opposite end is the target.
    /// `false`: the edge is reported reversed, so `edge_id.node` is the target.
    pub(crate) direction: bool,
    /// Size information of the shape, used to resolve the wrapped endpoint on looping axes.
    pub(crate) s: S::SizeInfo,
    /// Marks the shape type `S` without storing a value of it.
    pub(crate) spd: PhantomData<fn() -> S>,
}
32
33impl<'a, E, Ix: IndexType, S: Shape> Clone for EdgeReference<'a, E, Ix, S> {
34 fn clone(&self) -> Self {
35 *self
36 }
37}
38
39impl<'a, E, Ix: IndexType, S: Shape> Copy for EdgeReference<'a, E, Ix, S> {}
40
41impl<'a, E, Ix: IndexType, S: Shape> EdgeReference<'a, E, Ix, S> {
42 #[inline]
43 fn get_node(&self, is_source: bool) -> NodeIndex<Ix> {
44 let node = self.edge_id.node;
45 if is_source {
46 node
47 } else {
48 match self.edge_id.axis {
49 Axis::Horizontal => {
50 if S::LOOP_HORIZONTAL
51 && node.horizontal.index() + 1 == unsafe { self.s.horizontal_size() }
52 {
53 NodeIndex::new(Ix::new(0), node.vertical)
54 } else {
55 node.right()
56 }
57 }
58 Axis::Vertical => {
59 if S::LOOP_VERTICAL
60 && node.vertical.index() + 1 == unsafe { self.s.vertical_size() }
61 {
62 NodeIndex::new(node.horizontal, Ix::new(0))
63 } else {
64 node.up()
65 }
66 }
67 }
68 }
69 }
70}
71
72impl<'a, E: Copy, Ix: IndexType, S: Shape> EdgeRef for EdgeReference<'a, E, Ix, S> {
73 type NodeId = NodeIndex<Ix>;
74 type EdgeId = EdgeIndex<Ix>;
75 type Weight = E;
76
77 #[inline]
78 fn source(&self) -> Self::NodeId {
79 self.get_node(self.direction)
80 }
81
82 #[inline]
83 fn target(&self) -> Self::NodeId {
84 self.get_node(!self.direction)
85 }
86
87 fn weight(&self) -> &Self::Weight {
88 self.edge_weight
89 }
90
91 fn id(&self) -> Self::EdgeId {
92 self.edge_id
93 }
94}
95
/// Iterator over every edge of a `SquareGraph`, created by `edge_references`.
///
/// For each node (in `NodeIndices` order) it first tries the horizontal edge stored at that
/// node and then the vertical one. Positions with no entry in the corresponding weight array
/// are skipped.
#[derive(Clone, Debug)]
pub struct EdgeReferences<'a, E, Ix: IndexType, S> {
    /// Weights of the horizontal edges, indexed by `(horizontal, vertical)` of the stored node.
    horizontal: &'a Array2<E>,
    /// Weights of the vertical edges, indexed by `(horizontal, vertical)` of the stored node.
    vertical: &'a Array2<E>,
    /// Remaining nodes to visit.
    nodes: NodeIndices<Ix>,
    /// `Some(node)` once `node`'s horizontal edge has been tried and its vertical edge is
    /// still pending; `None` when the next node should be pulled from `nodes`.
    prv: Option<NodeIndex<Ix>>,
    /// Marks the shape type `S`.
    s: PhantomData<S>,
}
105
106impl<'a, E, Ix: IndexType, S> EdgeReferences<'a, E, Ix, S> {
107 fn new<N>(graph: &'a SquareGraph<N, E, Ix, S>) -> Self {
108 Self {
109 horizontal: &graph.horizontal,
110 vertical: &graph.vertical,
111 nodes: NodeIndices::new(graph.horizontal_node_count(), graph.vertical_node_count()),
112 prv: None,
113 s: PhantomData,
114 }
115 }
116}
117
118impl<'a, E, Ix, S> Iterator for EdgeReferences<'a, E, Ix, S>
119where
120 Ix: IndexType,
121 S: Shape,
122{
123 type Item = EdgeReference<'a, E, Ix, S>;
124
125 fn next(&mut self) -> Option<Self::Item> {
126 let s = S::get_sizeinfo(self.nodes.h_max, self.nodes.v_max);
127 loop {
128 match self.prv {
129 None => {
130 let x = self.nodes.next()?;
131 let e = EdgeIndex {
132 node: x,
133 axis: Axis::Horizontal,
134 };
135 self.prv = Some(x);
136 let ew = self
137 .horizontal
138 .get((x.horizontal.index(), x.vertical.index()));
139 if let Some(ew) = ew {
140 return Some(EdgeReference {
141 edge_id: e,
142 edge_weight: ew,
143 direction: true,
144 s,
145 spd: PhantomData,
146 });
147 }
148 }
149 Some(x) => {
150 self.prv = None;
151 let ew = self
152 .vertical
153 .get((x.horizontal.index(), x.vertical.index()));
154 if let Some(ew) = ew {
155 return Some(EdgeReference {
156 edge_id: EdgeIndex {
157 node: x,
158 axis: Axis::Vertical,
159 },
160 edge_weight: ew,
161 direction: true,
162 s,
163 spd: PhantomData,
164 });
165 }
166 }
167 }
168 }
169 }
170}
171
/// Iterator over the edges incident to a single node, created by `IntoEdges::edges`.
///
/// It tries the left, right, lower and upper neighbour in that order.
#[derive(Clone, Debug)]
pub struct Edges<'a, N, E, Ix: IndexType, S> {
    /// Graph whose edges are being enumerated.
    g: &'a SquareGraph<N, E, Ix, S>,
    /// Node whose incident edges are yielded.
    node: NodeIndex<Ix>,
    /// Next neighbour to try: 0 = left, 1 = right, 2 = down, 3 = up; any other value means
    /// the iterator is exhausted.
    state: usize,
}
179
180impl<'a, N, E, Ix, S> Iterator for Edges<'a, N, E, Ix, S>
181where
182 Ix: IndexType,
183 S: Shape,
184{
185 type Item = EdgeReference<'a, E, Ix, S>;
186
187 fn next(&mut self) -> Option<Self::Item> {
188 let g = self.g;
189 let n = self.node;
190 let s = S::get_sizeinfo(g.horizontal_node_count(), g.vertical_node_count());
191 loop {
192 'inner: {
193 let er = match self.state {
194 0 => {
195 let new_n = if n.horizontal.index() == 0 {
196 if !S::LOOP_HORIZONTAL {
197 break 'inner;
198 }
199 NodeIndex::new(Ix::new(g.horizontal_node_count() - 1), n.vertical)
200 } else {
201 n.left()
202 };
203 EdgeReference {
204 edge_id: EdgeIndex {
205 node: new_n,
206 axis: Axis::Horizontal,
207 },
208 edge_weight: unsafe {
209 g.horizontal
210 .uget((new_n.horizontal.index(), new_n.vertical.index()))
211 },
212 direction: false,
213 s,
214 spd: PhantomData,
215 }
216 }
217 1 => {
218 debug_assert!(n.horizontal.index() < g.horizontal_node_count());
219 if !S::LOOP_HORIZONTAL
220 && n.horizontal.index() + 1 == g.horizontal_node_count()
221 {
222 break 'inner;
223 }
224 EdgeReference {
225 edge_id: EdgeIndex {
226 node: n,
227 axis: Axis::Horizontal,
228 },
229 edge_weight: unsafe {
230 g.horizontal
231 .uget((n.horizontal.index(), n.vertical.index()))
232 },
233 direction: true,
234 s,
235 spd: PhantomData,
236 }
237 }
238 2 => {
239 let new_n = if n.vertical.index() == 0 {
240 if !S::LOOP_VERTICAL {
241 break 'inner;
242 }
243 NodeIndex::new(n.horizontal, Ix::new(g.vertical_node_count() - 1))
244 } else {
245 n.down()
246 };
247 EdgeReference {
248 edge_id: EdgeIndex {
249 node: new_n,
250 axis: Axis::Vertical,
251 },
252 edge_weight: unsafe {
253 g.vertical
254 .uget((new_n.horizontal.index(), new_n.vertical.index()))
255 },
256 direction: false,
257 s,
258 spd: PhantomData,
259 }
260 }
261 3 => {
262 debug_assert!(n.vertical.index() < g.vertical_node_count());
263 if !S::LOOP_VERTICAL && n.vertical.index() + 1 == g.vertical_node_count() {
264 break 'inner;
265 }
266 EdgeReference {
267 edge_id: EdgeIndex {
268 node: n,
269 axis: Axis::Vertical,
270 },
271 edge_weight: unsafe {
272 g.vertical.uget((n.horizontal.index(), n.vertical.index()))
273 },
274 direction: true,
275 s,
276 spd: PhantomData,
277 }
278 }
279 _ => return None,
280 };
281 self.state += 1;
282 return Some(er);
283 }
284 self.state += 1;
285 }
286 }
287
288 fn size_hint(&self) -> (usize, Option<usize>) {
289 (0, Some(4 - self.state))
290 }
291}
292
293impl<'a, N, E, Ix, S> FusedIterator for Edges<'a, N, E, Ix, S>
294where
295 Ix: IndexType,
296 S: Shape,
297{
298}
299
300impl<'a, N, E, Ix, S> IntoEdges for &'a SquareGraph<N, E, Ix, S>
301where
302 Ix: IndexType,
303 E: Copy,
304 S: Shape,
305{
306 type Edges = Edges<'a, N, E, Ix, S>;
307
308 fn edges(self, a: Self::NodeId) -> Self::Edges {
309 Edges {
310 g: self,
311 node: a,
312 state: 0,
313 }
314 }
315}
316
317