use crate::model::descriptor::{FlattenDataFlowDescriptor, InputDescriptor, OutputDescriptor};
use crate::types::{NodeId, PortId};
use crate::zferror;
use crate::zfresult::ErrorKind;
use crate::Result as ZFResult;
use petgraph::graph::{EdgeIndex, NodeIndex};
use petgraph::Graph;
use std::collections::{HashMap, HashSet};
use std::convert::TryFrom;
use std::sync::Arc;
/// Structural checker for a flattened data flow.
///
/// Two graphs are maintained side by side:
/// - `graph_checker` has one vertex per data flow node and one edge per link;
/// - `node_checker` has one vertex per port and one edge per link, going from the
///   output port to the input port it feeds.
pub(crate) struct DataFlowValidator {
    /// One vertex per node (id, kind); each edge carries
    /// (output port, input port, index of that very edge).
    graph_checker: Graph<(NodeId, NodeKind), (PortId, PortId, EdgeIndex)>,
    /// Node id -> (kind, vertex index in `graph_checker`).
    map_id_to_graph_checker_idx: HashMap<NodeId, (NodeKind, NodeIndex)>,
    /// One vertex per port.
    node_checker: Graph<PortUniqueId, ()>,
    /// Port -> vertex index in `node_checker`.
    map_id_to_node_checker_idx: HashMap<PortUniqueId, NodeIndex>,
    /// `node_checker` indices of every registered input port.
    input_indexes: HashSet<NodeIndex>,
    /// `node_checker` indices of every registered output port.
    output_indexes: HashSet<NodeIndex>,
}
/// Direction of a port relative to the node that owns it.
///
/// This is a fieldless enum, so it is `Copy`, like `NodeKind`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum PortKind {
    /// Port that is the target of links.
    Input,
    /// Port that is the origin of links.
    Output,
}
/// The role a node plays in the data flow.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
enum NodeKind {
    /// Node registered through `try_add_source` (outputs only).
    Source,
    /// Node registered through `try_add_sink` (inputs only).
    Sink,
    /// Node registered through `try_add_operator` (inputs and outputs).
    Operator,
}
/// Identifies a port across the whole data flow: owning node, port id and direction.
///
/// The direction is part of the identity, so an input and an output of the same
/// node may share the same `port_id` without being considered duplicates.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
struct PortUniqueId {
    /// Node that owns the port.
    node_id: NodeId,
    /// Port identifier within that node.
    port_id: PortId,
    /// Whether the port is an input or an output.
    kind: PortKind,
}
impl TryFrom<&FlattenDataFlowDescriptor> for DataFlowValidator {
    type Error = crate::zfresult::Error;

    /// Builds a validator from a flattened descriptor.
    ///
    /// Sources, operators and sinks are registered first (so that every node and
    /// port is known), then the links are added. The first failing registration
    /// aborts the conversion and its error is returned.
    ///
    /// Note that this does not call `validate_ports`: unconnected ports are not
    /// reported by the conversion itself.
    fn try_from(descriptor: &FlattenDataFlowDescriptor) -> Result<Self, Self::Error> {
        let mut validator = DataFlowValidator::new();

        descriptor
            .sources
            .iter()
            .try_for_each(|source| validator.try_add_source(source.id.clone(), &source.outputs))?;

        descriptor.operators.iter().try_for_each(|operator| {
            validator.try_add_operator(operator.id.clone(), &operator.inputs, &operator.outputs)
        })?;

        descriptor
            .sinks
            .iter()
            .try_for_each(|sink| validator.try_add_sink(sink.id.clone(), &sink.inputs))?;

        // Links are added last: they reference nodes and ports registered above.
        descriptor
            .links
            .iter()
            .try_for_each(|link| validator.try_add_link(&link.from, &link.to))?;

        Ok(validator)
    }
}
impl DataFlowValidator {
    /// Creates an empty validator with no nodes, ports or links registered.
    pub(crate) fn new() -> Self {
        Self {
            graph_checker: Graph::new(),
            input_indexes: HashSet::new(),
            output_indexes: HashSet::new(),
            map_id_to_node_checker_idx: HashMap::new(),
            map_id_to_graph_checker_idx: HashMap::new(),
            node_checker: Graph::new(),
        }
    }

    /// Registers a data flow node of kind `node_kind` as a vertex of `graph_checker`.
    ///
    /// # Errors
    ///
    /// Returns `ErrorKind::DuplicatedNodeId` if a node with the same id was already
    /// registered. The duplicate check happens before anything is mutated, so a
    /// failing call adds no vertex and keeps the index of the first registration.
    fn try_add_id(&mut self, node_kind: NodeKind, node_id: NodeId) -> ZFResult<()> {
        if self.map_id_to_graph_checker_idx.contains_key(&node_id) {
            return Err(zferror!(ErrorKind::DuplicatedNodeId(node_id), "Duplicate node").into());
        }
        let graph_checker_idx = self.graph_checker.add_node((node_id.clone(), node_kind));
        self.map_id_to_graph_checker_idx
            .insert(node_id, (node_kind, graph_checker_idx));
        Ok(())
    }

    /// Registers port `port` (of direction `port_kind`) of node `node_id` as a vertex
    /// of `node_checker` and returns the index of that vertex.
    ///
    /// # Errors
    ///
    /// Returns `ErrorKind::DuplicatedPort` if the same (node, port, direction) triple
    /// was already registered. Nothing is mutated in that case.
    fn try_add_node(
        &mut self,
        node_id: NodeId,
        port: PortId,
        port_kind: PortKind,
    ) -> ZFResult<NodeIndex> {
        let id = PortUniqueId {
            node_id,
            port_id: port,
            kind: port_kind,
        };
        if self.map_id_to_node_checker_idx.contains_key(&id) {
            return Err(zferror!(ErrorKind::DuplicatedPort((
                Arc::clone(&id.node_id),
                Arc::clone(&id.port_id)
            )))
            .into());
        }
        let node_checker_idx = self.node_checker.add_node(id.clone());
        self.map_id_to_node_checker_idx.insert(id, node_checker_idx);
        Ok(node_checker_idx)
    }

    /// Registers `input` as an input port of `node_id`.
    ///
    /// # Errors
    ///
    /// Propagates `ErrorKind::DuplicatedPort` from port registration.
    pub(crate) fn try_add_input(&mut self, node_id: NodeId, input: PortId) -> ZFResult<()> {
        let node_checker_idx = self.try_add_node(node_id, input, PortKind::Input)?;
        self.input_indexes.insert(node_checker_idx);
        Ok(())
    }

    /// Registers `output` as an output port of `node_id`.
    ///
    /// # Errors
    ///
    /// Propagates `ErrorKind::DuplicatedPort` from port registration.
    pub(crate) fn try_add_output(&mut self, node_id: NodeId, output: PortId) -> ZFResult<()> {
        let node_checker_idx = self.try_add_node(node_id, output, PortKind::Output)?;
        self.output_indexes.insert(node_checker_idx);
        Ok(())
    }

    /// Registers a source node together with its output ports.
    ///
    /// # Errors
    ///
    /// `ErrorKind::DuplicatedNodeId` or `ErrorKind::DuplicatedPort`; stops at the
    /// first failing output.
    pub(crate) fn try_add_source(&mut self, node_id: NodeId, outputs: &[PortId]) -> ZFResult<()> {
        self.try_add_id(NodeKind::Source, node_id.clone())?;
        outputs
            .iter()
            .try_for_each(|output| self.try_add_output(node_id.clone(), output.clone()))
    }

    /// Registers a sink node together with its input ports.
    ///
    /// # Errors
    ///
    /// `ErrorKind::DuplicatedNodeId` or `ErrorKind::DuplicatedPort`; stops at the
    /// first failing input.
    pub(crate) fn try_add_sink(&mut self, node_id: NodeId, inputs: &[PortId]) -> ZFResult<()> {
        self.try_add_id(NodeKind::Sink, node_id.clone())?;
        inputs
            .iter()
            .try_for_each(|input| self.try_add_input(node_id.clone(), input.clone()))
    }

    /// Registers an operator node together with its input and output ports.
    ///
    /// # Errors
    ///
    /// `ErrorKind::DuplicatedNodeId` or `ErrorKind::DuplicatedPort`; stops at the
    /// first failing port (inputs are registered before outputs).
    pub(crate) fn try_add_operator(
        &mut self,
        node_id: NodeId,
        inputs: &[PortId],
        outputs: &[PortId],
    ) -> ZFResult<()> {
        self.try_add_id(NodeKind::Operator, node_id.clone())?;
        inputs
            .iter()
            .try_for_each(|input| self.try_add_input(node_id.clone(), input.clone()))?;
        outputs
            .iter()
            .try_for_each(|output| self.try_add_output(node_id.clone(), output.clone()))
    }

    /// Records a link from output `from.output` of node `from.node` to input
    /// `to.input` of node `to.node`, in both `graph_checker` and `node_checker`.
    ///
    /// # Errors
    ///
    /// - `ErrorKind::NodeNotFound` if either node was not registered;
    /// - `ErrorKind::PortNotFound` if the output (resp. input) port was not
    ///   registered on its node.
    ///
    /// All lookups happen before either graph is modified, so a failing call leaves
    /// the validator unchanged (no stray edge is added).
    pub(crate) fn try_add_link(
        &mut self,
        from: &OutputDescriptor,
        to: &InputDescriptor,
    ) -> ZFResult<()> {
        log::debug!("Looking for node < {} >…", &from.node);
        let &(_, from_graph_checker_idx) = self
            .map_id_to_graph_checker_idx
            .get(&from.node)
            .ok_or_else(|| zferror!(ErrorKind::NodeNotFound(from.node.clone())))?;
        log::debug!("Looking for node < {} >… OK.", &from.node);

        log::debug!("Looking for node < {} >…", &to.node);
        let &(_, to_graph_checker_idx) = self
            .map_id_to_graph_checker_idx
            .get(&to.node)
            .ok_or_else(|| zferror!(ErrorKind::NodeNotFound(to.node.clone())))?;
        log::debug!("Looking for node < {} >… OK.", &to.node);

        let from_id = PortUniqueId {
            node_id: from.node.clone(),
            port_id: from.output.clone(),
            kind: PortKind::Output,
        };
        let to_id = PortUniqueId {
            node_id: to.node.clone(),
            port_id: to.input.clone(),
            kind: PortKind::Input,
        };
        log::trace!("FromId {from_id:?} ToId: {to_id:?}");

        let from_node_checker_idx = *self
            .map_id_to_node_checker_idx
            .get(&from_id)
            .ok_or_else(|| {
                zferror!(ErrorKind::PortNotFound((
                    from.node.clone(),
                    from.output.clone()
                )))
            })?;
        let to_node_checker_idx = *self.map_id_to_node_checker_idx.get(&to_id).ok_or_else(|| {
            zferror!(ErrorKind::PortNotFound((to.node.clone(), to.input.clone())))
        })?;

        // Every lookup succeeded: it is now safe to mutate both graphs.
        let edge_idx = self.graph_checker.add_edge(
            from_graph_checker_idx,
            to_graph_checker_idx,
            (from.output.clone(), to.input.clone(), EdgeIndex::default()),
        );
        // The edge weight stores its own index; it is only known after insertion.
        // The lookup should not fail since the edge was just added.
        let edge_weight = self
            .graph_checker
            .edge_weight_mut(edge_idx)
            .ok_or_else(|| {
                zferror!(
                    ErrorKind::NotFound,
                    "Link with id {edge_idx:?}, between {from:?} => {to:?} not found"
                )
            })?;
        edge_weight.2 = edge_idx;

        self.node_checker
            .add_edge(from_node_checker_idx, to_node_checker_idx, ());
        Ok(())
    }

    /// Checks that every input port has at least one incoming link and that every
    /// output port has at least one outgoing link.
    ///
    /// Inputs are checked before outputs. Within each group, ports are visited in
    /// ascending `node_checker` index order, so the reported port is deterministic.
    ///
    /// # Errors
    ///
    /// Returns `ErrorKind::PortNotConnected` for the first unconnected port found.
    pub(crate) fn validate_ports(&self) -> ZFResult<()> {
        self.ensure_connected(&self.input_indexes, petgraph::EdgeDirection::Incoming)?;
        self.ensure_connected(&self.output_indexes, petgraph::EdgeDirection::Outgoing)
    }

    /// Fails with `ErrorKind::PortNotConnected` on the first port of `indexes`
    /// (ascending index order) that has no edge in `direction`.
    fn ensure_connected(
        &self,
        indexes: &HashSet<NodeIndex>,
        direction: petgraph::EdgeDirection,
    ) -> ZFResult<()> {
        // A HashSet iterates in an unspecified order; sort to report a stable port.
        let mut ordered: Vec<NodeIndex> = indexes.iter().copied().collect();
        ordered.sort_unstable();

        ordered.into_iter().try_for_each(|idx| {
            if self.node_checker.edges_directed(idx, direction).next().is_some() {
                return Ok(());
            }
            let port = self
                .node_checker
                .node_weight(idx)
                .expect("port indexes are only ever produced by `node_checker.add_node`");
            Err(zferror!(ErrorKind::PortNotConnected((
                port.node_id.clone(),
                port.port_id.clone()
            )))
            .into())
        })
    }
}