1use crate::iterators::{Components, Connections, Neighbors, Siblings};
7use crate::{ComponentGraph, Edge, Error, Node};
8use std::collections::BTreeSet;
9
10impl<N, E> ComponentGraph<N, E>
12where
13 N: Node,
14 E: Edge,
15{
16 pub fn component(&self, component_id: u64) -> Result<&N, Error> {
18 self.node_indices
19 .get(&component_id)
20 .map(|i| &self.graph[*i])
21 .ok_or_else(|| {
22 Error::component_not_found(format!("Component with id {component_id} not found."))
23 })
24 }
25
26 pub fn components(&self) -> Components<'_, N> {
28 Components {
29 iter: self.graph.raw_nodes().iter(),
30 }
31 }
32
33 pub fn connections(&self) -> Connections<'_, N, E> {
35 Connections {
36 cg: self,
37 iter: self.graph.raw_edges().iter(),
38 }
39 }
40
41 pub fn predecessors(&self, component_id: u64) -> Result<Neighbors<'_, N>, Error> {
46 self.node_indices
47 .get(&component_id)
48 .map(|&index| Neighbors {
49 graph: &self.graph,
50 iter: self
51 .graph
52 .neighbors_directed(index, petgraph::Direction::Incoming),
53 })
54 .ok_or_else(|| {
55 Error::component_not_found(format!("Component with id {component_id} not found."))
56 })
57 }
58
59 pub fn successors(&self, component_id: u64) -> Result<Neighbors<'_, N>, Error> {
64 self.node_indices
65 .get(&component_id)
66 .map(|&index| Neighbors {
67 graph: &self.graph,
68 iter: self
69 .graph
70 .neighbors_directed(index, petgraph::Direction::Outgoing),
71 })
72 .ok_or_else(|| {
73 Error::component_not_found(format!("Component with id {component_id} not found."))
74 })
75 }
76
77 pub(crate) fn siblings_from_predecessors(
82 &self,
83 component_id: u64,
84 ) -> Result<Siblings<'_, N>, Error> {
85 Ok(Siblings::new(
86 component_id,
87 self.predecessors(component_id)?
88 .map(|x| self.successors(x.component_id()))
89 .collect::<Result<Vec<_>, _>>()?
90 .into_iter()
91 .flatten(),
92 ))
93 }
94
95 pub(crate) fn siblings_from_successors(
100 &self,
101 component_id: u64,
102 ) -> Result<Siblings<'_, N>, Error> {
103 Ok(Siblings::new(
104 component_id,
105 self.successors(component_id)?
106 .map(|x| self.predecessors(x.component_id()))
107 .collect::<Result<Vec<_>, _>>()?
108 .into_iter()
109 .flatten(),
110 ))
111 }
112
113 pub(crate) fn find_all(
119 &self,
120 from: u64,
121 mut pred: impl FnMut(&N) -> bool,
122 direction: petgraph::Direction,
123 follow_after_match: bool,
124 ) -> Result<BTreeSet<u64>, Error> {
125 let index = self.node_indices.get(&from).ok_or_else(|| {
126 Error::component_not_found(format!("Component with id {from} not found."))
127 })?;
128 let mut stack = vec![*index];
129 let mut found = BTreeSet::new();
130
131 while let Some(index) = stack.pop() {
132 let node = &self.graph[index];
133 if pred(node) {
134 found.insert(node.component_id());
135 if !follow_after_match {
136 continue;
137 }
138 }
139
140 let neighbors = self.graph.neighbors_directed(index, direction);
141 stack.extend(neighbors);
142 }
143
144 Ok(found)
145 }
146}
147
#[cfg(test)]
mod tests {
    use super::*;
    use crate::component_category::BatteryType;
    use crate::component_category::CategoryPredicates;
    use crate::error::Error;
    use crate::graph::test_utils::ComponentGraphBuilder;
    use crate::graph::test_utils::{TestComponent, TestConnection};
    use crate::ComponentCategory;
    use crate::ComponentGraphConfig;
    use crate::InverterType;

    // Fixture topology (edges point away from the grid):
    //
    //   1 grid -> 2 meter -> 3 meter -> 4 battery inverter -> 5 battery
    //                     -> 6 meter -> 7 battery inverter -> 8 battery
    //
    // Both lists are in no particular id order; `test_components` and
    // `test_connections` check that the graph reports them in this same order.
    fn nodes_and_edges() -> (Vec<TestComponent>, Vec<TestConnection>) {
        (
            vec![
                TestComponent::new(6, ComponentCategory::Meter),
                TestComponent::new(1, ComponentCategory::GridConnectionPoint),
                TestComponent::new(7, ComponentCategory::Inverter(InverterType::Battery)),
                TestComponent::new(3, ComponentCategory::Meter),
                TestComponent::new(5, ComponentCategory::Battery(BatteryType::Unspecified)),
                TestComponent::new(8, ComponentCategory::Battery(BatteryType::LiIon)),
                TestComponent::new(4, ComponentCategory::Inverter(InverterType::Battery)),
                TestComponent::new(2, ComponentCategory::Meter),
            ],
            vec![
                TestConnection::new(3, 4),
                TestConnection::new(1, 2),
                TestConnection::new(7, 8),
                TestConnection::new(4, 5),
                TestConnection::new(2, 3),
                TestConnection::new(6, 7),
                TestConnection::new(2, 6),
            ],
        )
    }

    // Builds the fixture graph from `nodes_and_edges` with the default config.
    fn default_graph() -> Result<ComponentGraph<TestComponent, TestConnection>, Error> {
        let (components, connections) = nodes_and_edges();
        Ok(ComponentGraph::try_new(
            components,
            connections,
            ComponentGraphConfig::default(),
        )?)
    }

    #[test]
    fn test_component() -> Result<(), Error> {
        let graph = default_graph()?;

        let grid = TestComponent::new(1, ComponentCategory::GridConnectionPoint);
        assert_eq!(graph.component(1), Ok(&grid));

        let battery =
            TestComponent::new(5, ComponentCategory::Battery(BatteryType::Unspecified));
        assert_eq!(graph.component(5), Ok(&battery));

        assert_eq!(
            graph.component(9),
            Err(Error::component_not_found("Component with id 9 not found."))
        );

        Ok(())
    }

    #[test]
    fn test_components() -> Result<(), Error> {
        let (components, _) = nodes_and_edges();
        let graph = default_graph()?;

        assert!(graph.components().eq(&components));

        let batteries = [
            TestComponent::new(5, ComponentCategory::Battery(BatteryType::Unspecified)),
            TestComponent::new(8, ComponentCategory::Battery(BatteryType::LiIon)),
        ];
        assert!(graph
            .components()
            .filter(|component| component.is_battery())
            .eq(&batteries));

        Ok(())
    }

    #[test]
    fn test_connections() -> Result<(), Error> {
        let (_, connections) = nodes_and_edges();
        let graph = default_graph()?;

        assert!(graph.connections().eq(&connections));

        let leaving_meter_two = [TestConnection::new(2, 3), TestConnection::new(2, 6)];
        assert!(graph
            .connections()
            .filter(|connection| connection.source() == 2)
            .eq(&leaving_meter_two));

        Ok(())
    }

    #[test]
    fn test_neighbors() -> Result<(), Error> {
        let graph = default_graph()?;
        let meter = |id| TestComponent::new(id, ComponentCategory::Meter);

        assert_eq!(graph.predecessors(1)?.count(), 0);
        assert!(graph.predecessors(3)?.eq(&[meter(2)]));

        assert!(graph.successors(1)?.eq(&[meter(2)]));
        assert!(graph.successors(2)?.eq(&[meter(6), meter(3)]));
        assert_eq!(graph.successors(5)?.count(), 0);

        let missing = || Error::component_not_found("Component with id 32 not found.");
        assert!(graph.predecessors(32).is_err_and(|e| e == missing()));
        assert!(graph.successors(32).is_err_and(|e| e == missing()));

        Ok(())
    }

    #[test]
    fn test_siblings() -> Result<(), Error> {
        // The builder hands out ids in call order, so the sequence of builder
        // calls below fixes every id asserted in this test.
        let mut builder = ComponentGraphBuilder::new();
        let grid = builder.grid();

        let grid_meter = builder.meter();
        builder.connect(grid, grid_meter);
        assert_eq!(grid_meter.component_id(), 1);

        let meter_bat_chain = builder.meter_bat_chain(3, 2);
        builder.connect(grid_meter, meter_bat_chain);
        assert_eq!(meter_bat_chain.component_id(), 2);

        let graph = builder.build(None)?;

        let battery_inverter =
            |id| TestComponent::new(id, ComponentCategory::Inverter(InverterType::Battery));
        let (inverter_5, inverter_4) = (battery_inverter(5), battery_inverter(4));

        assert_eq!(
            graph.siblings_from_predecessors(3)?.collect::<Vec<_>>(),
            [&inverter_5, &inverter_4]
        );
        assert_eq!(
            graph.siblings_from_successors(3)?.collect::<Vec<_>>(),
            [&inverter_5, &inverter_4]
        );
        assert_eq!(graph.siblings_from_successors(6)?.count(), 0);
        assert_eq!(
            graph.siblings_from_predecessors(6)?.collect::<Vec<_>>(),
            [&TestComponent::new(
                7,
                ComponentCategory::Battery(BatteryType::LiIon)
            )]
        );

        let first_dangling = builder.meter();
        builder.connect(grid_meter, first_dangling);
        assert_eq!(first_dangling.component_id(), 8);

        let second_dangling = builder.meter();
        builder.connect(grid_meter, second_dangling);
        assert_eq!(second_dangling.component_id(), 9);

        let graph = builder.build(None)?;
        assert_eq!(
            graph.siblings_from_predecessors(8)?.collect::<Vec<_>>(),
            [
                &TestComponent::new(9, ComponentCategory::Meter),
                &TestComponent::new(2, ComponentCategory::Meter),
            ]
        );

        Ok(())
    }

    #[test]
    fn test_find_all() -> Result<(), Error> {
        let graph = default_graph()?;
        let root = graph.root_id;
        let outgoing = petgraph::Direction::Outgoing;
        let neither_grid_nor_meter = |component: &TestComponent| {
            !component.is_grid()
                && !graph
                    .is_component_meter(component.component_id())
                    .unwrap_or(false)
        };

        assert_eq!(
            graph.find_all(root, |c| c.is_meter(), outgoing, false)?,
            BTreeSet::from([2])
        );
        assert_eq!(
            graph.find_all(root, |c| c.is_meter(), outgoing, true)?,
            BTreeSet::from([2, 3, 6])
        );
        assert_eq!(
            graph.find_all(root, neither_grid_nor_meter, outgoing, true)?,
            BTreeSet::from([2, 4, 5, 7, 8])
        );
        assert_eq!(
            graph.find_all(6, neither_grid_nor_meter, outgoing, true)?,
            BTreeSet::from([7, 8])
        );
        assert_eq!(
            graph.find_all(root, neither_grid_nor_meter, outgoing, false)?,
            BTreeSet::from([2])
        );
        assert_eq!(
            graph.find_all(root, |_| true, outgoing, false)?,
            BTreeSet::from([1])
        );
        assert_eq!(
            graph.find_all(3, |_| true, outgoing, true)?,
            BTreeSet::from([3, 4, 5])
        );

        Ok(())
    }
}