use crate::iterators::{Components, Connections, Neighbors, RawNeighbors, Siblings};
use crate::{ComponentGraph, Edge, Error, Node};
use petgraph::graph::NodeIndex;
use std::collections::{BTreeSet, HashSet, VecDeque};
impl<N, E> ComponentGraph<N, E>
where
N: Node,
E: Edge,
{
pub fn component(&self, component_id: u64) -> Result<&N, Error> {
self.node_indices
.get(&component_id)
.map(|i| &self.graph[*i])
.ok_or_else(|| {
Error::component_not_found(format!("Component with id {component_id} not found."))
})
}
/// Returns an iterator over every component in the graph.
///
/// The iterator wraps the underlying graph's raw node storage, so
/// components are yielded in the order that storage holds them.
pub fn components(&self) -> Components<'_, N> {
    let iter = self.graph.raw_nodes().iter();
    Components { iter }
}
/// Returns an iterator over every connection in the graph.
///
/// The iterator wraps the underlying graph's raw edge storage and keeps a
/// reference back to this `ComponentGraph`.
pub fn connections(&self) -> Connections<'_, N, E> {
    let iter = self.graph.raw_edges().iter();
    Connections { cg: self, iter }
}
/// Returns the direct predecessors of `component_id`, exactly as stored in
/// the graph (passthrough components are *not* skipped).
///
/// # Errors
///
/// Returns a "component not found" error if `component_id` is unknown.
pub fn raw_predecessors(&self, component_id: u64) -> Result<RawNeighbors<'_, N>, Error> {
    let direction = petgraph::Direction::Incoming;
    self.raw_neighbors(component_id, direction)
}
/// Returns the direct successors of `component_id`, exactly as stored in
/// the graph (passthrough components are *not* skipped).
///
/// # Errors
///
/// Returns a "component not found" error if `component_id` is unknown.
pub fn raw_successors(&self, component_id: u64) -> Result<RawNeighbors<'_, N>, Error> {
    let direction = petgraph::Direction::Outgoing;
    self.raw_neighbors(component_id, direction)
}
fn raw_neighbors(
&self,
component_id: u64,
direction: petgraph::Direction,
) -> Result<RawNeighbors<'_, N>, Error> {
self.node_indices
.get(&component_id)
.map(|&index| RawNeighbors {
graph: &self.graph,
iter: self.graph.neighbors_directed(index, direction),
})
.ok_or_else(|| {
Error::component_not_found(format!("Component with id {component_id} not found."))
})
}
/// Returns the effective predecessors of `component_id`: the nearest
/// non-passthrough components reachable by following incoming edges,
/// looking through any passthrough components on the way.
///
/// # Errors
///
/// Returns a "component not found" error if `component_id` is unknown.
pub fn predecessors(&self, component_id: u64) -> Result<Neighbors<'_, N>, Error> {
    let direction = petgraph::Direction::Incoming;
    self.collect_effective_neighbors(component_id, direction)
}
/// Returns the effective successors of `component_id`: the nearest
/// non-passthrough components reachable by following outgoing edges,
/// looking through any passthrough components on the way.
///
/// # Errors
///
/// Returns a "component not found" error if `component_id` is unknown.
pub fn successors(&self, component_id: u64) -> Result<Neighbors<'_, N>, Error> {
    let direction = petgraph::Direction::Outgoing;
    self.collect_effective_neighbors(component_id, direction)
}
fn collect_effective_neighbors(
&self,
component_id: u64,
direction: petgraph::Direction,
) -> Result<Neighbors<'_, N>, Error> {
let start = *self.node_indices.get(&component_id).ok_or_else(|| {
Error::component_not_found(format!("Component with id {component_id} not found."))
})?;
let mut queue: VecDeque<NodeIndex> =
self.graph.neighbors_directed(start, direction).collect();
let mut visited: HashSet<NodeIndex> = HashSet::new();
let mut result: Vec<&N> = Vec::new();
while let Some(idx) = queue.pop_front() {
if !visited.insert(idx) {
continue;
}
let node = &self.graph[idx];
if node.category().is_passthrough() {
queue.extend(self.graph.neighbors_directed(idx, direction));
} else {
result.push(node);
}
}
Ok(Neighbors {
iter: result.into_iter(),
})
}
/// Builds the sibling iterator for `component_id` by going up and back
/// down: for every effective predecessor of the component, that
/// predecessor's effective successors are gathered and the combined
/// sequence is handed to `Siblings::new` together with `component_id`.
///
/// # Errors
///
/// Returns a "component not found" error if `component_id`, or any
/// predecessor looked up along the way, is unknown.
pub(crate) fn siblings_from_predecessors(
    &self,
    component_id: u64,
) -> Result<Siblings<'_, N>, Error> {
    let mut successors_per_parent = Vec::new();
    for parent in self.predecessors(component_id)? {
        successors_per_parent.push(self.successors(parent.component_id())?);
    }

    Ok(Siblings::new(
        component_id,
        successors_per_parent.into_iter().flatten(),
    ))
}
/// Builds the sibling iterator for `component_id` by going down and back
/// up: for every effective successor of the component, that successor's
/// effective predecessors are gathered and the combined sequence is handed
/// to `Siblings::new` together with `component_id`.
///
/// # Errors
///
/// Returns a "component not found" error if `component_id`, or any
/// successor looked up along the way, is unknown.
pub(crate) fn siblings_from_successors(
    &self,
    component_id: u64,
) -> Result<Siblings<'_, N>, Error> {
    let mut predecessors_per_child = Vec::new();
    for child in self.successors(component_id)? {
        predecessors_per_child.push(self.predecessors(child.component_id())?);
    }

    Ok(Siblings::new(
        component_id,
        predecessors_per_child.into_iter().flatten(),
    ))
}
pub(crate) fn find_all(
&self,
from: u64,
mut pred: impl FnMut(&N) -> bool,
direction: petgraph::Direction,
follow_after_match: bool,
) -> Result<BTreeSet<u64>, Error> {
let index = self.node_indices.get(&from).ok_or_else(|| {
Error::component_not_found(format!("Component with id {from} not found."))
})?;
let mut stack = vec![*index];
let mut found = BTreeSet::new();
while let Some(index) = stack.pop() {
let node = &self.graph[index];
if !node.category().is_passthrough() && pred(node) {
found.insert(node.component_id());
if !follow_after_match {
continue;
}
}
let neighbors = self.graph.neighbors_directed(index, direction);
stack.extend(neighbors);
}
Ok(found)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::ComponentCategory;
use crate::ComponentGraphConfig;
use crate::InverterType;
use crate::component_category::BatteryType;
use crate::component_category::CategoryPredicates;
use crate::error::Error;
use crate::graph::test_utils::ComponentGraphBuilder;
use crate::graph::test_utils::{TestComponent, TestConnection};
/// Builds the fixture shared by several tests: a grid connection point,
/// three meters, two battery inverters and two batteries, plus the
/// connections between them.
///
/// Both lists are intentionally kept in this exact (non-sorted) order;
/// tests compare the graph's iteration order against them.
fn nodes_and_edges() -> (Vec<TestComponent>, Vec<TestConnection>) {
    let components: Vec<TestComponent> = vec![
        (6, ComponentCategory::Meter),
        (1, ComponentCategory::GridConnectionPoint),
        (7, ComponentCategory::Inverter(InverterType::Battery)),
        (3, ComponentCategory::Meter),
        (5, ComponentCategory::Battery(BatteryType::Unspecified)),
        (8, ComponentCategory::Battery(BatteryType::LiIon)),
        (4, ComponentCategory::Inverter(InverterType::Battery)),
        (2, ComponentCategory::Meter),
    ]
    .into_iter()
    .map(|(id, category)| TestComponent::new(id, category))
    .collect();

    let connections: Vec<TestConnection> =
        vec![(3, 4), (1, 2), (7, 8), (4, 5), (2, 3), (6, 7), (2, 6)]
            .into_iter()
            .map(|(source, destination)| TestConnection::new(source, destination))
            .collect();

    (components, connections)
}
/// `component` returns the stored component for known ids and a
/// "not found" error for an unknown id.
#[test]
fn test_component() -> Result<(), Error> {
    let (components, connections) = nodes_and_edges();
    let graph =
        ComponentGraph::try_new(components, connections, ComponentGraphConfig::default())?;

    let grid = TestComponent::new(1, ComponentCategory::GridConnectionPoint);
    let battery = TestComponent::new(5, ComponentCategory::Battery(BatteryType::Unspecified));
    let missing = Error::component_not_found("Component with id 9 not found.");

    assert_eq!(graph.component(1), Ok(&grid));
    assert_eq!(graph.component(5), Ok(&battery));
    assert_eq!(graph.component(9), Err(missing));

    Ok(())
}
/// `components` yields every component in insertion order and composes
/// with ordinary iterator adaptors such as `filter`.
#[test]
fn test_components() -> Result<(), Error> {
    let (components, connections) = nodes_and_edges();
    let graph = ComponentGraph::try_new(
        components.clone(),
        connections,
        ComponentGraphConfig::default(),
    )?;

    assert!(graph.components().eq(components.iter()));

    let batteries = [
        TestComponent::new(5, ComponentCategory::Battery(BatteryType::Unspecified)),
        TestComponent::new(8, ComponentCategory::Battery(BatteryType::LiIon)),
    ];
    let found = graph.components().filter(|component| component.is_battery());
    assert!(found.eq(batteries.iter()));

    Ok(())
}
/// `connections` yields every connection in insertion order and composes
/// with ordinary iterator adaptors such as `filter`.
#[test]
fn test_connections() -> Result<(), Error> {
    let (components, connections) = nodes_and_edges();
    let graph = ComponentGraph::try_new(
        components,
        connections.clone(),
        ComponentGraphConfig::default(),
    )?;

    assert!(graph.connections().eq(connections.iter()));

    let leaving_two = [TestConnection::new(2, 3), TestConnection::new(2, 6)];
    let found = graph
        .connections()
        .filter(|connection| connection.source() == 2);
    assert!(found.eq(leaving_two.iter()));

    Ok(())
}
/// `predecessors` / `successors` return the expected neighbours on the
/// shared fixture and report unknown ids as "not found" errors.
#[test]
fn test_neighbors() -> Result<(), Error> {
    let (components, connections) = nodes_and_edges();
    let graph =
        ComponentGraph::try_new(components, connections, ComponentGraphConfig::default())?;

    let meter = |id: u64| TestComponent::new(id, ComponentCategory::Meter);

    // The root has nothing above it.
    assert!(
        graph
            .predecessors(1)
            .is_ok_and(|mut found| found.next().is_none())
    );
    assert!(
        graph
            .predecessors(3)
            .is_ok_and(|found| found.eq(&[meter(2)]))
    );
    assert!(
        graph
            .successors(1)
            .is_ok_and(|found| found.eq(&[meter(2)]))
    );
    assert!(
        graph
            .successors(2)
            .is_ok_and(|found| found.eq(&[meter(6), meter(3)]))
    );
    // A battery is a leaf.
    assert!(
        graph
            .successors(5)
            .is_ok_and(|mut found| found.next().is_none())
    );

    // Unknown ids are rejected in both directions.
    assert_eq!(
        graph.predecessors(32).err(),
        Some(Error::component_not_found(
            "Component with id 32 not found."
        ))
    );
    assert_eq!(
        graph.successors(32).err(),
        Some(Error::component_not_found(
            "Component with id 32 not found."
        ))
    );

    Ok(())
}
/// Siblings are found both via shared predecessors and via shared
/// successors, and never include the queried component itself.
///
/// The expected ids rely on the order in which the builder hands out
/// component ids, so the builder calls below must stay in this sequence.
#[test]
fn test_siblings() -> Result<(), Error> {
    let battery_inverter =
        |id: u64| TestComponent::new(id, ComponentCategory::Inverter(InverterType::Battery));

    let mut builder = ComponentGraphBuilder::new();
    let grid = builder.grid();
    let grid_meter = builder.meter();
    builder.connect(grid, grid_meter);
    assert_eq!(grid_meter.component_id(), 1);

    // Presumably a meter feeding three inverters and two batteries; the
    // ids asserted below are consistent with meter 2, inverters 3-5 and
    // batteries 6-7.
    let meter_bat_chain = builder.meter_bat_chain(3, 2);
    builder.connect(grid_meter, meter_bat_chain);
    assert_eq!(meter_bat_chain.component_id(), 2);

    let graph = builder.build(None)?;

    let (inverter_5, inverter_4) = (battery_inverter(5), battery_inverter(4));
    let other_inverters = vec![&inverter_5, &inverter_4];

    let via_predecessors: Vec<_> = graph.siblings_from_predecessors(3).unwrap().collect();
    assert_eq!(via_predecessors, other_inverters);

    let via_successors: Vec<_> = graph.siblings_from_successors(3).unwrap().collect();
    assert_eq!(via_successors, other_inverters);

    // A battery has no successors, hence no siblings in that direction.
    let battery_via_successors: Vec<_> = graph.siblings_from_successors(6).unwrap().collect();
    assert_eq!(battery_via_successors, Vec::<&TestComponent>::new());

    let battery_7 = TestComponent::new(7, ComponentCategory::Battery(BatteryType::LiIon));
    let battery_via_predecessors: Vec<_> =
        graph.siblings_from_predecessors(6).unwrap().collect();
    assert_eq!(battery_via_predecessors, vec![&battery_7]);

    // Extend the same builder with two meters hanging directly off the grid meter.
    let dangling_meter = builder.meter();
    builder.connect(grid_meter, dangling_meter);
    assert_eq!(dangling_meter.component_id(), 8);

    let dangling_meter = builder.meter();
    builder.connect(grid_meter, dangling_meter);
    assert_eq!(dangling_meter.component_id(), 9);

    let graph = builder.build(None)?;

    let meter_9 = TestComponent::new(9, ComponentCategory::Meter);
    let meter_2 = TestComponent::new(2, ComponentCategory::Meter);
    let meter_siblings: Vec<_> = graph.siblings_from_predecessors(8).unwrap().collect();
    assert_eq!(meter_siblings, vec![&meter_9, &meter_2]);

    Ok(())
}
/// With a power transformer between the grid and a meter, the `raw_*`
/// accessors report the transformer itself, while `predecessors` /
/// `successors` look through it to the component on the other side.
#[test]
fn test_raw_neighbors_includes_passthroughs() -> Result<(), Error> {
    let mut builder = ComponentGraphBuilder::new();
    let grid = builder.grid();
    let pt = builder.power_transformer();
    let meter = builder.meter();
    let inverter = builder.battery_inverter();
    let battery = builder.battery();

    // grid -> transformer -> meter -> inverter -> battery
    builder.connect(grid, pt);
    builder.connect(pt, meter);
    builder.connect(meter, inverter);
    builder.connect(inverter, battery);

    let graph = builder.build(None)?;

    let grid_id = grid.component_id();
    let pt_id = pt.component_id();
    let meter_id = meter.component_id();

    // Raw accessors see the transformer.
    let raw_above_meter = graph
        .raw_predecessors(meter_id)?
        .map(|node| node.component_id())
        .collect::<Vec<u64>>();
    assert_eq!(raw_above_meter, vec![pt_id]);

    let raw_below_grid = graph
        .raw_successors(grid_id)?
        .map(|node| node.component_id())
        .collect::<Vec<u64>>();
    assert_eq!(raw_below_grid, vec![pt_id]);

    // Effective accessors skip over it.
    let above_meter = graph
        .predecessors(meter_id)?
        .map(|node| node.component_id())
        .collect::<Vec<u64>>();
    assert_eq!(above_meter, vec![grid_id]);

    let below_grid = graph
        .successors(grid_id)?
        .map(|node| node.component_id())
        .collect::<Vec<u64>>();
    assert_eq!(below_grid, vec![meter_id]);

    // Unknown ids are errors for the raw accessors as well.
    assert!(graph.raw_predecessors(999).is_err());
    assert!(graph.raw_successors(999).is_err());

    // These handles are only needed to build the topology.
    let _ = (battery, inverter);

    Ok(())
}
/// `find_all` walks through a power transformer but never reports it,
/// not even when the predicate asks for exactly that category.
#[test]
fn test_find_all_skips_passthroughs() -> Result<(), Error> {
    let mut builder = ComponentGraphBuilder::new();
    let grid = builder.grid();
    let pt = builder.power_transformer();
    let meter = builder.meter();
    builder.connect(grid, pt);
    builder.connect(pt, meter);
    let graph = builder.build(None)?;

    let outgoing = petgraph::Direction::Outgoing;
    let grid_id = grid.component_id();
    let meter_id = meter.component_id();

    // An accept-everything predicate still omits the transformer.
    let everything = graph.find_all(grid_id, |_| true, outgoing, true)?;
    assert_eq!(everything, BTreeSet::from([grid_id, meter_id]));

    // Asking for the transformer explicitly yields nothing.
    let transformers = graph.find_all(
        grid_id,
        |node| node.category() == ComponentCategory::PowerTransformer,
        outgoing,
        true,
    )?;
    assert!(transformers.is_empty());

    Ok(())
}
/// Exercises `find_all` on the shared fixture with different predicates,
/// start points and both settings of `follow_after_match`.
#[test]
fn test_find_all() -> Result<(), Error> {
    /// Expected-result helper: turns a list of ids into a set.
    fn ids(expected: &[u64]) -> BTreeSet<u64> {
        expected.iter().copied().collect()
    }

    let (components, connections) = nodes_and_edges();
    let graph =
        ComponentGraph::try_new(components, connections, ComponentGraphConfig::default())?;

    let outgoing = petgraph::Direction::Outgoing;
    let root = graph.root_id;

    // Matches everything that is neither the grid nor something the graph
    // classifies as a component meter.
    let not_grid_or_component_meter = |x: &TestComponent| {
        !x.is_grid() && !graph.is_component_meter(x.component_id()).unwrap_or(false)
    };

    // Meters: stopping at the first match only reaches meter 2.
    let found = graph.find_all(root, |x| x.is_meter(), outgoing, false)?;
    assert_eq!(found, ids(&[2]));

    // Meters: following past matches reaches all of them.
    let found = graph.find_all(root, |x| x.is_meter(), outgoing, true)?;
    assert_eq!(found, ids(&[2, 3, 6]));

    let found = graph.find_all(root, not_grid_or_component_meter, outgoing, true)?;
    assert_eq!(found, ids(&[2, 4, 5, 7, 8]));

    // Same predicate, but starting further down the graph.
    let found = graph.find_all(6, not_grid_or_component_meter, outgoing, true)?;
    assert_eq!(found, ids(&[7, 8]));

    let found = graph.find_all(root, not_grid_or_component_meter, outgoing, false)?;
    assert_eq!(found, ids(&[2]));

    // The start component itself is a candidate.
    let found = graph.find_all(root, |_| true, outgoing, false)?;
    assert_eq!(found, ids(&[1]));

    let found = graph.find_all(3, |_| true, outgoing, true)?;
    assert_eq!(found, ids(&[3, 4, 5]));

    Ok(())
}
}