1use crate::iterators::{Components, Connections, Neighbors, RawNeighbors, Siblings};
7use crate::{ComponentGraph, Edge, Error, Node};
8use petgraph::graph::NodeIndex;
9use std::collections::{BTreeSet, HashSet, VecDeque};
10
11impl<N, E> ComponentGraph<N, E>
13where
14 N: Node,
15 E: Edge,
16{
17 pub fn component(&self, component_id: u64) -> Result<&N, Error> {
19 self.node_indices
20 .get(&component_id)
21 .map(|i| &self.graph[*i])
22 .ok_or_else(|| {
23 Error::component_not_found(format!("Component with id {component_id} not found."))
24 })
25 }
26
27 pub fn components(&self) -> Components<'_, N> {
29 Components {
30 iter: self.graph.raw_nodes().iter(),
31 }
32 }
33
34 pub fn connections(&self) -> Connections<'_, N, E> {
36 Connections {
37 cg: self,
38 iter: self.graph.raw_edges().iter(),
39 }
40 }
41
42 pub fn raw_predecessors(&self, component_id: u64) -> Result<RawNeighbors<'_, N>, Error> {
52 self.raw_neighbors(component_id, petgraph::Direction::Incoming)
53 }
54
55 pub fn raw_successors(&self, component_id: u64) -> Result<RawNeighbors<'_, N>, Error> {
65 self.raw_neighbors(component_id, petgraph::Direction::Outgoing)
66 }
67
68 fn raw_neighbors(
71 &self,
72 component_id: u64,
73 direction: petgraph::Direction,
74 ) -> Result<RawNeighbors<'_, N>, Error> {
75 self.node_indices
76 .get(&component_id)
77 .map(|&index| RawNeighbors {
78 graph: &self.graph,
79 iter: self.graph.neighbors_directed(index, direction),
80 })
81 .ok_or_else(|| {
82 Error::component_not_found(format!("Component with id {component_id} not found."))
83 })
84 }
85
86 pub fn predecessors(&self, component_id: u64) -> Result<Neighbors<'_, N>, Error> {
97 self.collect_effective_neighbors(component_id, petgraph::Direction::Incoming)
98 }
99
100 pub fn successors(&self, component_id: u64) -> Result<Neighbors<'_, N>, Error> {
111 self.collect_effective_neighbors(component_id, petgraph::Direction::Outgoing)
112 }
113
114 fn collect_effective_neighbors(
117 &self,
118 component_id: u64,
119 direction: petgraph::Direction,
120 ) -> Result<Neighbors<'_, N>, Error> {
121 let start = *self.node_indices.get(&component_id).ok_or_else(|| {
122 Error::component_not_found(format!("Component with id {component_id} not found."))
123 })?;
124
125 let mut queue: VecDeque<NodeIndex> =
126 self.graph.neighbors_directed(start, direction).collect();
127 let mut visited: HashSet<NodeIndex> = HashSet::new();
128 let mut result: Vec<&N> = Vec::new();
129
130 while let Some(idx) = queue.pop_front() {
131 if !visited.insert(idx) {
132 continue;
133 }
134 let node = &self.graph[idx];
135 if node.category().is_passthrough() {
136 queue.extend(self.graph.neighbors_directed(idx, direction));
137 } else {
138 result.push(node);
139 }
140 }
141
142 Ok(Neighbors {
143 iter: result.into_iter(),
144 })
145 }
146
147 pub(crate) fn siblings_from_predecessors(
152 &self,
153 component_id: u64,
154 ) -> Result<Siblings<'_, N>, Error> {
155 Ok(Siblings::new(
156 component_id,
157 self.predecessors(component_id)?
158 .map(|x| self.successors(x.component_id()))
159 .collect::<Result<Vec<_>, _>>()?
160 .into_iter()
161 .flatten(),
162 ))
163 }
164
165 pub(crate) fn siblings_from_successors(
170 &self,
171 component_id: u64,
172 ) -> Result<Siblings<'_, N>, Error> {
173 Ok(Siblings::new(
174 component_id,
175 self.successors(component_id)?
176 .map(|x| self.predecessors(x.component_id()))
177 .collect::<Result<Vec<_>, _>>()?
178 .into_iter()
179 .flatten(),
180 ))
181 }
182
183 pub(crate) fn find_all(
189 &self,
190 from: u64,
191 mut pred: impl FnMut(&N) -> bool,
192 direction: petgraph::Direction,
193 follow_after_match: bool,
194 ) -> Result<BTreeSet<u64>, Error> {
195 let index = self.node_indices.get(&from).ok_or_else(|| {
196 Error::component_not_found(format!("Component with id {from} not found."))
197 })?;
198 let mut stack = vec![*index];
199 let mut found = BTreeSet::new();
200
201 while let Some(index) = stack.pop() {
202 let node = &self.graph[index];
203 if !node.category().is_passthrough() && pred(node) {
206 found.insert(node.component_id());
207 if !follow_after_match {
208 continue;
209 }
210 }
211
212 let neighbors = self.graph.neighbors_directed(index, direction);
213 stack.extend(neighbors);
214 }
215
216 Ok(found)
217 }
218}
219
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ComponentCategory;
    use crate::ComponentGraphConfig;
    use crate::InverterType;
    use crate::component_category::BatteryType;
    use crate::component_category::CategoryPredicates;
    use crate::error::Error;
    use crate::graph::test_utils::ComponentGraphBuilder;
    use crate::graph::test_utils::{TestComponent, TestConnection};

    /// Builds the fixture shared by several tests below:
    ///
    /// ```text
    /// 1 (grid) -> 2 (meter) -+-> 3 (meter) -> 4 (battery inverter) -> 5 (battery)
    ///                        +-> 6 (meter) -> 7 (battery inverter) -> 8 (battery)
    /// ```
    ///
    /// Both the components and the connections are listed out of id order.
    fn nodes_and_edges() -> (Vec<TestComponent>, Vec<TestConnection>) {
        let components = vec![
            TestComponent::new(6, ComponentCategory::Meter),
            TestComponent::new(1, ComponentCategory::GridConnectionPoint),
            TestComponent::new(7, ComponentCategory::Inverter(InverterType::Battery)),
            TestComponent::new(3, ComponentCategory::Meter),
            TestComponent::new(5, ComponentCategory::Battery(BatteryType::Unspecified)),
            TestComponent::new(8, ComponentCategory::Battery(BatteryType::LiIon)),
            TestComponent::new(4, ComponentCategory::Inverter(InverterType::Battery)),
            TestComponent::new(2, ComponentCategory::Meter),
        ];
        let connections = vec![
            TestConnection::new(3, 4),
            TestConnection::new(1, 2),
            TestConnection::new(7, 8),
            TestConnection::new(4, 5),
            TestConnection::new(2, 3),
            TestConnection::new(6, 7),
            TestConnection::new(2, 6),
        ];

        (components, connections)
    }

    #[test]
    fn test_component() -> Result<(), Error> {
        let config = ComponentGraphConfig::default();
        let (components, connections) = nodes_and_edges();
        let graph = ComponentGraph::try_new(components, connections, config)?;

        assert_eq!(
            graph.component(1),
            Ok(&TestComponent::new(
                1,
                ComponentCategory::GridConnectionPoint
            ))
        );
        assert_eq!(
            graph.component(5),
            Ok(&TestComponent::new(
                5,
                ComponentCategory::Battery(BatteryType::Unspecified)
            ))
        );
        assert_eq!(
            graph.component(9),
            Err(Error::component_not_found("Component with id 9 not found."))
        );

        Ok(())
    }

    #[test]
    fn test_components() -> Result<(), Error> {
        let config = ComponentGraphConfig::default();
        let (components, connections) = nodes_and_edges();
        // `components` is compared against below, so the graph gets a copy.
        let graph = ComponentGraph::try_new(components.clone(), connections, config)?;

        assert!(graph.components().eq(&components));
        assert!(graph.components().filter(|x| x.is_battery()).eq(&[
            TestComponent::new(5, ComponentCategory::Battery(BatteryType::Unspecified)),
            TestComponent::new(8, ComponentCategory::Battery(BatteryType::LiIon))
        ]));

        Ok(())
    }

    #[test]
    fn test_connections() -> Result<(), Error> {
        let config = ComponentGraphConfig::default();
        let (components, connections) = nodes_and_edges();
        // `connections` is compared against below, so the graph gets a copy.
        let graph = ComponentGraph::try_new(components, connections.clone(), config)?;

        assert!(graph.connections().eq(&connections));

        assert!(
            graph
                .connections()
                .filter(|x| x.source() == 2)
                .eq(&[TestConnection::new(2, 3), TestConnection::new(2, 6)])
        );

        Ok(())
    }

    #[test]
    fn test_neighbors() -> Result<(), Error> {
        let config = ComponentGraphConfig::default();
        let (components, connections) = nodes_and_edges();
        let graph = ComponentGraph::try_new(components, connections, config)?;

        // Collecting into vectors and using `assert_eq!` makes a failure print
        // the actual neighbors instead of just "assertion failed".
        assert_eq!(
            graph.predecessors(1)?.collect::<Vec<_>>(),
            Vec::<&TestComponent>::new()
        );

        assert_eq!(
            graph.predecessors(3)?.collect::<Vec<_>>(),
            [&TestComponent::new(2, ComponentCategory::Meter)]
        );

        assert_eq!(
            graph.successors(1)?.collect::<Vec<_>>(),
            [&TestComponent::new(2, ComponentCategory::Meter)]
        );

        // Order follows the order in which the underlying graph yields
        // neighbors, not component ids.
        assert_eq!(
            graph.successors(2)?.collect::<Vec<_>>(),
            [
                &TestComponent::new(6, ComponentCategory::Meter),
                &TestComponent::new(3, ComponentCategory::Meter),
            ]
        );

        assert_eq!(
            graph.successors(5)?.collect::<Vec<_>>(),
            Vec::<&TestComponent>::new()
        );

        assert_eq!(
            graph.predecessors(32).err(),
            Some(Error::component_not_found("Component with id 32 not found."))
        );
        assert_eq!(
            graph.successors(32).err(),
            Some(Error::component_not_found("Component with id 32 not found."))
        );

        Ok(())
    }

    #[test]
    fn test_siblings() -> Result<(), Error> {
        let mut builder = ComponentGraphBuilder::new();
        let grid = builder.grid();

        let grid_meter = builder.meter();
        builder.connect(grid, grid_meter);
        assert_eq!(grid_meter.component_id(), 1);

        // Meter 2 heads a chain of battery inverters and batteries.
        let meter_bat_chain = builder.meter_bat_chain(3, 2);
        builder.connect(grid_meter, meter_bat_chain);
        assert_eq!(meter_bat_chain.component_id(), 2);

        let graph = builder.build(None)?;
        assert_eq!(
            graph.siblings_from_predecessors(3)?.collect::<Vec<_>>(),
            [
                &TestComponent::new(5, ComponentCategory::Inverter(InverterType::Battery)),
                &TestComponent::new(4, ComponentCategory::Inverter(InverterType::Battery))
            ]
        );

        assert_eq!(
            graph.siblings_from_successors(3)?.collect::<Vec<_>>(),
            [
                &TestComponent::new(5, ComponentCategory::Inverter(InverterType::Battery)),
                &TestComponent::new(4, ComponentCategory::Inverter(InverterType::Battery))
            ]
        );

        assert_eq!(
            graph.siblings_from_successors(6)?.collect::<Vec<_>>(),
            Vec::<&TestComponent>::new()
        );

        assert_eq!(
            graph.siblings_from_predecessors(6)?.collect::<Vec<_>>(),
            [&TestComponent::new(
                7,
                ComponentCategory::Battery(BatteryType::LiIon)
            )]
        );

        // Two extra meters hanging directly off the grid meter become siblings
        // of meter 2 and of each other.
        let dangling_meter = builder.meter();
        builder.connect(grid_meter, dangling_meter);
        assert_eq!(dangling_meter.component_id(), 8);

        let dangling_meter = builder.meter();
        builder.connect(grid_meter, dangling_meter);
        assert_eq!(dangling_meter.component_id(), 9);

        let graph = builder.build(None)?;
        assert_eq!(
            graph.siblings_from_predecessors(8)?.collect::<Vec<_>>(),
            [
                &TestComponent::new(9, ComponentCategory::Meter),
                &TestComponent::new(2, ComponentCategory::Meter),
            ]
        );

        Ok(())
    }

    /// Raw neighbor queries report passthrough components, while the
    /// effective `predecessors`/`successors` look through them.
    #[test]
    fn test_raw_neighbors_includes_passthroughs() -> Result<(), Error> {
        let mut builder = ComponentGraphBuilder::new();
        let grid = builder.grid();
        let pt = builder.power_transformer();
        let meter = builder.meter();
        let inverter = builder.battery_inverter();
        let battery = builder.battery();

        builder.connect(grid, pt);
        builder.connect(pt, meter);
        builder.connect(meter, inverter);
        builder.connect(inverter, battery);

        let graph = builder.build(None)?;

        // Raw: the power transformer sits between grid and meter.
        let raw_preds: Vec<u64> = graph
            .raw_predecessors(meter.component_id())?
            .map(|n| n.component_id())
            .collect();
        assert_eq!(raw_preds, vec![pt.component_id()]);

        let raw_succs: Vec<u64> = graph
            .raw_successors(grid.component_id())?
            .map(|n| n.component_id())
            .collect();
        assert_eq!(raw_succs, vec![pt.component_id()]);

        // Effective: the power transformer is skipped over.
        let preds: Vec<u64> = graph
            .predecessors(meter.component_id())?
            .map(|n| n.component_id())
            .collect();
        assert_eq!(preds, vec![grid.component_id()]);

        let succs: Vec<u64> = graph
            .successors(grid.component_id())?
            .map(|n| n.component_id())
            .collect();
        assert_eq!(succs, vec![meter.component_id()]);

        // Unknown ids are reported as errors.
        assert!(graph.raw_predecessors(999).is_err());
        assert!(graph.raw_successors(999).is_err());

        Ok(())
    }

    /// `find_all` never reports passthrough components, even when the
    /// predicate would accept them, but still searches through them.
    #[test]
    fn test_find_all_skips_passthroughs() -> Result<(), Error> {
        let mut builder = ComponentGraphBuilder::new();
        let grid = builder.grid();
        let pt = builder.power_transformer();
        let meter = builder.meter();

        builder.connect(grid, pt);
        builder.connect(pt, meter);

        let graph = builder.build(None)?;

        let found = graph.find_all(
            grid.component_id(),
            |_| true,
            petgraph::Direction::Outgoing,
            true,
        )?;
        assert_eq!(
            found,
            BTreeSet::from([grid.component_id(), meter.component_id()])
        );

        let found = graph.find_all(
            grid.component_id(),
            |n| n.category() == ComponentCategory::PowerTransformer,
            petgraph::Direction::Outgoing,
            true,
        )?;
        assert!(found.is_empty());
        Ok(())
    }

    #[test]
    fn test_find_all() -> Result<(), Error> {
        let (components, connections) = nodes_and_edges();
        let graph =
            ComponentGraph::try_new(components, connections, ComponentGraphConfig::default())?;

        let found = graph.find_all(
            graph.root_id,
            |x| x.is_meter(),
            petgraph::Direction::Outgoing,
            false,
        )?;
        assert_eq!(found, BTreeSet::from([2]));

        let found = graph.find_all(
            graph.root_id,
            |x| x.is_meter(),
            petgraph::Direction::Outgoing,
            true,
        )?;
        assert_eq!(found, BTreeSet::from([2, 3, 6]));

        let found = graph.find_all(
            graph.root_id,
            |x| !x.is_grid() && !graph.is_component_meter(x.component_id()).unwrap_or(false),
            petgraph::Direction::Outgoing,
            true,
        )?;
        assert_eq!(found, BTreeSet::from([2, 4, 5, 7, 8]));

        let found = graph.find_all(
            6,
            |x| !x.is_grid() && !graph.is_component_meter(x.component_id()).unwrap_or(false),
            petgraph::Direction::Outgoing,
            true,
        )?;
        assert_eq!(found, BTreeSet::from([7, 8]));

        let found = graph.find_all(
            graph.root_id,
            |x| !x.is_grid() && !graph.is_component_meter(x.component_id()).unwrap_or(false),
            petgraph::Direction::Outgoing,
            false,
        )?;
        assert_eq!(found, BTreeSet::from([2]));

        let found = graph.find_all(
            graph.root_id,
            |_| true,
            petgraph::Direction::Outgoing,
            false,
        )?;
        assert_eq!(found, BTreeSet::from([1]));

        let found = graph.find_all(3, |_| true, petgraph::Direction::Outgoing, true)?;
        assert_eq!(found, BTreeSet::from([3, 4, 5]));

        Ok(())
    }
}