1use crate::iterators::{Components, Connections, Neighbors, Siblings};
7use crate::{ComponentGraph, Edge, Error, Node};
8use std::collections::BTreeSet;
9
10impl<N, E> ComponentGraph<N, E>
12where
13 N: Node,
14 E: Edge,
15{
16 pub fn component(&self, component_id: u64) -> Result<&N, Error> {
18 self.node_indices
19 .get(&component_id)
20 .map(|i| &self.graph[*i])
21 .ok_or_else(|| {
22 Error::component_not_found(format!("Component with id {component_id} not found."))
23 })
24 }
25
26 pub fn components(&self) -> Components<'_, N> {
28 Components {
29 iter: self.graph.raw_nodes().iter(),
30 }
31 }
32
33 pub fn connections(&self) -> Connections<'_, N, E> {
35 Connections {
36 cg: self,
37 iter: self.graph.raw_edges().iter(),
38 }
39 }
40
41 pub fn predecessors(&self, component_id: u64) -> Result<Neighbors<'_, N>, Error> {
46 self.node_indices
47 .get(&component_id)
48 .map(|&index| Neighbors {
49 graph: &self.graph,
50 iter: self
51 .graph
52 .neighbors_directed(index, petgraph::Direction::Incoming),
53 })
54 .ok_or_else(|| {
55 Error::component_not_found(format!("Component with id {component_id} not found."))
56 })
57 }
58
59 pub fn successors(&self, component_id: u64) -> Result<Neighbors<'_, N>, Error> {
64 self.node_indices
65 .get(&component_id)
66 .map(|&index| Neighbors {
67 graph: &self.graph,
68 iter: self
69 .graph
70 .neighbors_directed(index, petgraph::Direction::Outgoing),
71 })
72 .ok_or_else(|| {
73 Error::component_not_found(format!("Component with id {component_id} not found."))
74 })
75 }
76
77 pub(crate) fn siblings_from_predecessors(
82 &self,
83 component_id: u64,
84 ) -> Result<Siblings<'_, N>, Error> {
85 Ok(Siblings::new(
86 component_id,
87 self.predecessors(component_id)?
88 .map(|x| self.successors(x.component_id()))
89 .collect::<Result<Vec<_>, _>>()?
90 .into_iter()
91 .flatten(),
92 ))
93 }
94
95 pub(crate) fn siblings_from_successors(
100 &self,
101 component_id: u64,
102 ) -> Result<Siblings<'_, N>, Error> {
103 Ok(Siblings::new(
104 component_id,
105 self.successors(component_id)?
106 .map(|x| self.predecessors(x.component_id()))
107 .collect::<Result<Vec<_>, _>>()?
108 .into_iter()
109 .flatten(),
110 ))
111 }
112
113 pub(crate) fn find_all(
119 &self,
120 from: u64,
121 mut pred: impl FnMut(&N) -> bool,
122 direction: petgraph::Direction,
123 follow_after_match: bool,
124 ) -> Result<BTreeSet<u64>, Error> {
125 let index = self.node_indices.get(&from).ok_or_else(|| {
126 Error::component_not_found(format!("Component with id {from} not found."))
127 })?;
128 let mut stack = vec![*index];
129 let mut found = BTreeSet::new();
130
131 while let Some(index) = stack.pop() {
132 let node = &self.graph[index];
133 if pred(node) {
134 found.insert(node.component_id());
135 if !follow_after_match {
136 continue;
137 }
138 }
139
140 let neighbors = self.graph.neighbors_directed(index, direction);
141 stack.extend(neighbors);
142 }
143
144 Ok(found)
145 }
146}
147
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ComponentCategory;
    use crate::ComponentGraphConfig;
    use crate::InverterType;
    use crate::component_category::BatteryType;
    use crate::component_category::CategoryPredicates;
    use crate::error::Error;
    use crate::graph::test_utils::ComponentGraphBuilder;
    use crate::graph::test_utils::{TestComponent, TestConnection};

    /// Builds the shared fixture (components and connections are listed out of
    /// id order):
    ///
    /// ```text
    /// 1 (grid) -> 2 (meter) -+-> 3 (meter) -> 4 (battery inverter) -> 5 (battery)
    ///                        +-> 6 (meter) -> 7 (battery inverter) -> 8 (Li-ion battery)
    /// ```
    fn nodes_and_edges() -> (Vec<TestComponent>, Vec<TestConnection>) {
        let components = vec![
            TestComponent::new(6, ComponentCategory::Meter),
            TestComponent::new(1, ComponentCategory::GridConnectionPoint),
            TestComponent::new(7, ComponentCategory::Inverter(InverterType::Battery)),
            TestComponent::new(3, ComponentCategory::Meter),
            TestComponent::new(5, ComponentCategory::Battery(BatteryType::Unspecified)),
            TestComponent::new(8, ComponentCategory::Battery(BatteryType::LiIon)),
            TestComponent::new(4, ComponentCategory::Inverter(InverterType::Battery)),
            TestComponent::new(2, ComponentCategory::Meter),
        ];
        let connections = vec![
            TestConnection::new(3, 4),
            TestConnection::new(1, 2),
            TestConnection::new(7, 8),
            TestConnection::new(4, 5),
            TestConnection::new(2, 3),
            TestConnection::new(6, 7),
            TestConnection::new(2, 6),
        ];

        (components, connections)
    }

    #[test]
    fn test_component() -> Result<(), Error> {
        let config = ComponentGraphConfig::default();
        let (components, connections) = nodes_and_edges();
        let graph = ComponentGraph::try_new(components, connections, config)?;

        assert_eq!(
            graph.component(1),
            Ok(&TestComponent::new(
                1,
                ComponentCategory::GridConnectionPoint
            ))
        );
        assert_eq!(
            graph.component(5),
            Ok(&TestComponent::new(
                5,
                ComponentCategory::Battery(BatteryType::Unspecified)
            ))
        );
        assert_eq!(
            graph.component(9),
            Err(Error::component_not_found("Component with id 9 not found."))
        );

        Ok(())
    }

    #[test]
    fn test_components() -> Result<(), Error> {
        let config = ComponentGraphConfig::default();
        let (components, connections) = nodes_and_edges();
        // `components` is compared against below, so only it needs a copy.
        let graph = ComponentGraph::try_new(components.clone(), connections, config)?;

        assert!(graph.components().eq(&components));
        assert!(graph.components().filter(|x| x.is_battery()).eq(&[
            TestComponent::new(5, ComponentCategory::Battery(BatteryType::Unspecified)),
            TestComponent::new(8, ComponentCategory::Battery(BatteryType::LiIon))
        ]));

        Ok(())
    }

    #[test]
    fn test_connections() -> Result<(), Error> {
        let config = ComponentGraphConfig::default();
        let (components, connections) = nodes_and_edges();
        // `connections` is compared against below, so only it needs a copy.
        let graph = ComponentGraph::try_new(components, connections.clone(), config)?;

        assert!(graph.connections().eq(&connections));

        assert!(
            graph
                .connections()
                .filter(|x| x.source() == 2)
                .eq(&[TestConnection::new(2, 3), TestConnection::new(2, 6)])
        );

        Ok(())
    }

    #[test]
    fn test_neighbors() -> Result<(), Error> {
        let config = ComponentGraphConfig::default();
        let (components, connections) = nodes_and_edges();
        let graph = ComponentGraph::try_new(components, connections, config)?;

        assert!(graph.predecessors(1).is_ok_and(|x| x.eq(&[])));

        assert!(
            graph
                .predecessors(3)
                .is_ok_and(|x| x.eq(&[TestComponent::new(2, ComponentCategory::Meter)]))
        );

        assert!(
            graph
                .successors(1)
                .is_ok_and(|x| x.eq(&[TestComponent::new(2, ComponentCategory::Meter)]))
        );

        assert!(graph.successors(2).is_ok_and(|x| {
            x.eq(&[
                TestComponent::new(6, ComponentCategory::Meter),
                TestComponent::new(3, ComponentCategory::Meter),
            ])
        }));

        assert!(graph.successors(5).is_ok_and(|x| x.eq(&[])));

        assert!(
            graph
                .predecessors(32)
                .is_err_and(|e| e == Error::component_not_found("Component with id 32 not found."))
        );
        assert!(
            graph
                .successors(32)
                .is_err_and(|e| e == Error::component_not_found("Component with id 32 not found."))
        );

        Ok(())
    }

    #[test]
    fn test_siblings() -> Result<(), Error> {
        let mut builder = ComponentGraphBuilder::new();
        let grid = builder.grid();

        let grid_meter = builder.meter();
        builder.connect(grid, grid_meter);

        assert_eq!(grid_meter.component_id(), 1);

        // Judging by the ids asserted below, this chain is meter 2 feeding the
        // battery inverters 3, 4 and 5, which feed the batteries 6 and 7.
        let meter_bat_chain = builder.meter_bat_chain(3, 2);
        builder.connect(grid_meter, meter_bat_chain);

        assert_eq!(meter_bat_chain.component_id(), 2);

        let graph = builder.build(None)?;
        assert_eq!(
            graph.siblings_from_predecessors(3)?.collect::<Vec<_>>(),
            [
                &TestComponent::new(5, ComponentCategory::Inverter(InverterType::Battery)),
                &TestComponent::new(4, ComponentCategory::Inverter(InverterType::Battery))
            ]
        );

        assert_eq!(
            graph.siblings_from_successors(3)?.collect::<Vec<_>>(),
            [
                &TestComponent::new(5, ComponentCategory::Inverter(InverterType::Battery)),
                &TestComponent::new(4, ComponentCategory::Inverter(InverterType::Battery))
            ]
        );

        assert_eq!(
            graph.siblings_from_successors(6)?.collect::<Vec<_>>(),
            Vec::<&TestComponent>::new()
        );

        assert_eq!(
            graph.siblings_from_predecessors(6)?.collect::<Vec<_>>(),
            [&TestComponent::new(
                7,
                ComponentCategory::Battery(BatteryType::LiIon)
            )]
        );

        // Two extra meters hanging directly off the grid meter.
        let dangling_meter = builder.meter();
        builder.connect(grid_meter, dangling_meter);
        assert_eq!(dangling_meter.component_id(), 8);

        let dangling_meter = builder.meter();
        builder.connect(grid_meter, dangling_meter);
        assert_eq!(dangling_meter.component_id(), 9);

        let graph = builder.build(None)?;
        assert_eq!(
            graph.siblings_from_predecessors(8)?.collect::<Vec<_>>(),
            [
                &TestComponent::new(9, ComponentCategory::Meter),
                &TestComponent::new(2, ComponentCategory::Meter),
            ]
        );

        Ok(())
    }

    #[test]
    fn test_find_all() -> Result<(), Error> {
        let (components, connections) = nodes_and_edges();
        let graph =
            ComponentGraph::try_new(components, connections, ComponentGraphConfig::default())?;

        // The search stops at the first meter on each path.
        let found = graph.find_all(
            graph.root_id,
            |x| x.is_meter(),
            petgraph::Direction::Outgoing,
            false,
        )?;
        assert_eq!(found, BTreeSet::from([2]));

        // The search continues past matches, so it collects every meter.
        let found = graph.find_all(
            graph.root_id,
            |x| x.is_meter(),
            petgraph::Direction::Outgoing,
            true,
        )?;
        assert_eq!(found, BTreeSet::from([2, 3, 6]));

        let found = graph.find_all(
            graph.root_id,
            |x| !x.is_grid() && !graph.is_component_meter(x.component_id()).unwrap_or(false),
            petgraph::Direction::Outgoing,
            true,
        )?;
        assert_eq!(found, BTreeSet::from([2, 4, 5, 7, 8]));

        let found = graph.find_all(
            6,
            |x| !x.is_grid() && !graph.is_component_meter(x.component_id()).unwrap_or(false),
            petgraph::Direction::Outgoing,
            true,
        )?;
        assert_eq!(found, BTreeSet::from([7, 8]));

        let found = graph.find_all(
            graph.root_id,
            |x| !x.is_grid() && !graph.is_component_meter(x.component_id()).unwrap_or(false),
            petgraph::Direction::Outgoing,
            false,
        )?;
        assert_eq!(found, BTreeSet::from([2]));

        // The start component itself is tested against the predicate.
        let found = graph.find_all(
            graph.root_id,
            |_| true,
            petgraph::Direction::Outgoing,
            false,
        )?;
        assert_eq!(found, BTreeSet::from([1]));

        let found = graph.find_all(3, |_| true, petgraph::Direction::Outgoing, true)?;
        assert_eq!(found, BTreeSet::from([3, 4, 5]));

        Ok(())
    }
}