use itertools::Itertools;
use portgraph::{NodeIndex, PortOffset};
use thiserror::Error;
use crate::types::TypeRow;
use super::dataflow::{DataflowOpTrait, DataflowParent};
use super::{impl_validate_op, BasicBlock, ExitBlock, OpTag, OpTrait, OpType, ValidateOp};
#[non_exhaustive]
pub struct OpValidityFlags {
pub allowed_children: OpTag,
pub allowed_first_child: OpTag,
pub allowed_second_child: OpTag,
pub requires_children: bool,
pub requires_dag: bool,
pub edge_check: Option<fn(ChildrenEdgeData) -> Result<(), EdgeValidationError>>,
}
impl Default for OpValidityFlags {
fn default() -> Self {
Self {
allowed_children: OpTag::None,
allowed_first_child: OpTag::Any,
allowed_second_child: OpTag::Any,
requires_children: false,
requires_dag: false,
edge_check: None,
}
}
}
impl ValidateOp for super::Module {
fn validity_flags(&self) -> OpValidityFlags {
OpValidityFlags {
allowed_children: OpTag::ModuleOp,
requires_children: false,
..Default::default()
}
}
}
impl ValidateOp for super::Conditional {
fn validity_flags(&self) -> OpValidityFlags {
OpValidityFlags {
allowed_children: OpTag::Case,
requires_children: true,
requires_dag: false,
..Default::default()
}
}
fn validate_op_children<'a>(
&self,
children: impl DoubleEndedIterator<Item = (NodeIndex, &'a OpType)>,
) -> Result<(), ChildrenValidationError> {
let children = children.collect_vec();
if self.sum_rows.len() != children.len() {
return Err(ChildrenValidationError::InvalidConditionalSum {
child: children[0].0, expected_count: children.len(),
actual_sum_rows: self.sum_rows.clone(),
});
}
for (i, (child, optype)) in children.into_iter().enumerate() {
let case_op = optype
.as_case()
.expect("Child check should have already checked valid ops.");
let sig = &case_op.inner_signature();
if sig.input != self.case_input_row(i).unwrap() || sig.output != self.outputs {
return Err(ChildrenValidationError::ConditionalCaseSignature {
child,
optype: optype.clone(),
});
}
}
Ok(())
}
}
impl ValidateOp for super::CFG {
fn validity_flags(&self) -> OpValidityFlags {
OpValidityFlags {
allowed_children: OpTag::ControlFlowChild,
allowed_first_child: OpTag::DataflowBlock,
allowed_second_child: OpTag::BasicBlockExit,
requires_children: true,
requires_dag: false,
edge_check: Some(validate_cfg_edge),
..Default::default()
}
}
fn validate_op_children<'a>(
&self,
mut children: impl Iterator<Item = (NodeIndex, &'a OpType)>,
) -> Result<(), ChildrenValidationError> {
let (entry, entry_op) = children.next().unwrap();
let (exit, exit_op) = children.next().unwrap();
let entry_op = entry_op
.as_dataflow_block()
.expect("Child check should have already checked valid ops.");
let exit_op = exit_op
.as_exit_block()
.expect("Child check should have already checked valid ops.");
let sig = self.signature();
if entry_op.inner_signature().input() != sig.input() {
return Err(ChildrenValidationError::IOSignatureMismatch {
child: entry,
actual: entry_op.inner_signature().input().clone(),
expected: sig.input().clone(),
node_desc: "BasicBlock Input",
container_desc: "CFG",
});
}
if &exit_op.cfg_outputs != sig.output() {
return Err(ChildrenValidationError::IOSignatureMismatch {
child: exit,
actual: exit_op.cfg_outputs.clone(),
expected: sig.output().clone(),
node_desc: "BasicBlockExit Output",
container_desc: "CFG",
});
}
for (child, optype) in children {
if optype.tag() == OpTag::BasicBlockExit {
return Err(ChildrenValidationError::InternalExitChildren { child });
}
}
Ok(())
}
}
#[derive(Debug, Clone, PartialEq, Error)]
#[allow(missing_docs)]
#[non_exhaustive]
pub enum ChildrenValidationError {
#[error("Exit basic blocks are only allowed as the second child in a CFG graph")]
InternalExitChildren { child: NodeIndex },
#[error("A {optype} operation is only allowed as a {expected_position} child")]
InternalIOChildren {
child: NodeIndex,
optype: OpType,
expected_position: &'static str,
},
#[error("The {node_desc} node of a {container_desc} has a signature of {actual}, which differs from the expected type row {expected}")]
IOSignatureMismatch {
child: NodeIndex,
actual: TypeRow,
expected: TypeRow,
node_desc: &'static str,
container_desc: &'static str,
},
#[error("A conditional case has optype {sig}, which differs from the signature of Conditional container", sig=optype.dataflow_signature().unwrap_or_default())]
ConditionalCaseSignature { child: NodeIndex, optype: OpType },
#[error("The conditional container's branch Sum input should be a sum with {expected_count} elements, but it had {} elements. Sum rows: {actual_sum_rows:?}",
actual_sum_rows.len())]
InvalidConditionalSum {
child: NodeIndex,
expected_count: usize,
actual_sum_rows: Vec<TypeRow>,
},
}
impl ChildrenValidationError {
pub fn child(&self) -> NodeIndex {
match self {
ChildrenValidationError::InternalIOChildren { child, .. } => *child,
ChildrenValidationError::InternalExitChildren { child, .. } => *child,
ChildrenValidationError::ConditionalCaseSignature { child, .. } => *child,
ChildrenValidationError::IOSignatureMismatch { child, .. } => *child,
ChildrenValidationError::InvalidConditionalSum { child, .. } => *child,
}
}
}
#[derive(Debug, Clone, PartialEq, Error)]
#[allow(missing_docs)]
#[non_exhaustive]
pub enum EdgeValidationError {
#[error("The dataflow signature of two connected basic blocks does not match. The source type was {source_ty} but the target had type {target_types}",
source_ty = source_types.clone().unwrap_or_default(),
)]
CFGEdgeSignatureMismatch {
edge: ChildrenEdgeData,
source_types: Option<TypeRow>,
target_types: TypeRow,
},
}
impl EdgeValidationError {
pub fn edge(&self) -> &ChildrenEdgeData {
match self {
EdgeValidationError::CFGEdgeSignatureMismatch { edge, .. } => edge,
}
}
}
#[derive(Debug, Clone, PartialEq)]
pub struct ChildrenEdgeData {
pub source: NodeIndex,
pub target: NodeIndex,
pub source_op: OpType,
pub target_op: OpType,
pub source_port: PortOffset,
pub target_port: PortOffset,
}
impl<T: DataflowParent> ValidateOp for T {
fn validity_flags(&self) -> OpValidityFlags {
OpValidityFlags {
allowed_children: OpTag::DataflowChild,
allowed_first_child: OpTag::Input,
allowed_second_child: OpTag::Output,
requires_children: true,
requires_dag: true,
..Default::default()
}
}
fn validate_op_children<'a>(
&self,
children: impl DoubleEndedIterator<Item = (NodeIndex, &'a OpType)>,
) -> Result<(), ChildrenValidationError> {
let sig = self.inner_signature();
validate_io_nodes(&sig.input, &sig.output, "DataflowParent", children)
}
}
fn validate_io_nodes<'a>(
expected_input: &TypeRow,
expected_output: &TypeRow,
container_desc: &'static str,
mut children: impl Iterator<Item = (NodeIndex, &'a OpType)>,
) -> Result<(), ChildrenValidationError> {
let (first, first_optype) = children.next().unwrap();
let (second, second_optype) = children.next().unwrap();
let first_sig = first_optype.dataflow_signature().unwrap_or_default();
if &first_sig.output != expected_input {
return Err(ChildrenValidationError::IOSignatureMismatch {
child: first,
actual: first_sig.into_owned().output,
expected: expected_input.clone(),
node_desc: "Input",
container_desc,
});
}
let second_sig = second_optype.dataflow_signature().unwrap_or_default();
if &second_sig.input != expected_output {
return Err(ChildrenValidationError::IOSignatureMismatch {
child: second,
actual: second_sig.into_owned().input,
expected: expected_output.clone(),
node_desc: "Output",
container_desc,
});
}
for (child, optype) in children {
match optype.tag() {
OpTag::Input => {
return Err(ChildrenValidationError::InternalIOChildren {
child,
optype: optype.clone(),
expected_position: "first",
})
}
OpTag::Output => {
return Err(ChildrenValidationError::InternalIOChildren {
child,
optype: optype.clone(),
expected_position: "second",
})
}
_ => {}
}
}
Ok(())
}
fn validate_cfg_edge(edge: ChildrenEdgeData) -> Result<(), EdgeValidationError> {
let source = &edge
.source_op
.as_dataflow_block()
.expect("CFG sibling graphs can only contain basic block operations.");
let target_input = match &edge.target_op {
OpType::DataflowBlock(dfb) => dfb.dataflow_input(),
OpType::ExitBlock(exit) => exit.dataflow_input(),
_ => panic!("CFG sibling graphs can only contain basic block operations."),
};
let source_types = source.successor_input(edge.source_port.index());
if source_types.as_ref() != Some(target_input) {
let target_types = target_input.clone();
return Err(EdgeValidationError::CFGEdgeSignatureMismatch {
edge,
source_types,
target_types,
});
}
Ok(())
}
#[cfg(test)]
mod test {
use crate::extension::prelude::{usize_t, Noop};
use crate::ops;
use crate::ops::dataflow::IOTrait;
use cool_asserts::assert_matches;
use super::*;
#[test]
fn test_validate_io_nodes() {
let in_types: TypeRow = vec![usize_t()].into();
let out_types: TypeRow = vec![usize_t(), usize_t()].into();
let input_node: OpType = ops::Input::new(in_types.clone()).into();
let output_node = ops::Output::new(out_types.clone()).into();
let leaf_node = Noop(usize_t()).into();
let children = vec![
(0, &input_node),
(1, &output_node),
(2, &leaf_node),
(3, &leaf_node),
];
assert_eq!(
validate_io_nodes(&in_types, &out_types, "test", make_iter(&children)),
Ok(())
);
assert_matches!(
validate_io_nodes(&out_types, &out_types, "test", make_iter(&children)),
Err(ChildrenValidationError::IOSignatureMismatch { child, .. }) if child.index() == 0
);
assert_matches!(
validate_io_nodes(&in_types, &in_types, "test", make_iter(&children)),
Err(ChildrenValidationError::IOSignatureMismatch { child, .. }) if child.index() == 1
);
let children = vec![
(0, &input_node),
(1, &output_node),
(42, &leaf_node),
(2, &leaf_node),
(3, &output_node),
];
assert_matches!(
validate_io_nodes(&in_types, &out_types, "test", make_iter(&children)),
Err(ChildrenValidationError::InternalIOChildren { child, .. }) if child.index() == 3
);
}
fn make_iter<'a>(
children: &'a [(usize, &OpType)],
) -> impl DoubleEndedIterator<Item = (NodeIndex, &'a OpType)> {
children.iter().map(|(n, op)| (NodeIndex::new(*n), *op))
}
}
use super::{
AliasDecl, AliasDefn, Call, CallIndirect, Const, ExtensionOp, FuncDecl, Input, LoadConstant,
LoadFunction, OpaqueOp, Output, Tag,
};
impl_validate_op!(FuncDecl);
impl_validate_op!(AliasDecl);
impl_validate_op!(AliasDefn);
impl_validate_op!(Input);
impl_validate_op!(Output);
impl_validate_op!(Const);
impl_validate_op!(Call);
impl_validate_op!(LoadConstant);
impl_validate_op!(LoadFunction);
impl_validate_op!(CallIndirect);
impl_validate_op!(ExtensionOp);
impl_validate_op!(OpaqueOp);
impl_validate_op!(Tag);
impl_validate_op!(ExitBlock);