use crate::onnx::{ModelProto, NodeProto, TensorProto, ValueInfoProto};
use crate::utils::{DataTypeError, Shape};
use std::borrow::Cow;
use std::fmt::Debug;
use std::hash::Hash;
use std::ptr;
use std::{collections::HashMap, sync::Arc};
use thiserror::Error;
/// An ONNX operator node (`NodeProto`) bundled with the shapes of its outputs.
///
/// As built by `OperatorDefinition::from`, `output_shapes[i]` is the shape
/// looked up for the output named `proto.get_output()[i]`.
#[derive(Clone)]
pub struct OperatorDefinition<'model> {
/// The operator proto, either borrowed from the model or owned.
pub(crate) proto: Cow<'model, NodeProto>,
/// One shape per output, in the same order as the proto's output names
/// (when constructed through `OperatorDefinition::from`).
pub(crate) output_shapes: Vec<Shape>,
}
impl<'model> OperatorDefinition<'model> {
    /// Wraps `node` together with the shape of each of its outputs.
    ///
    /// Every output name of `node` is looked up in `value_shapes`; the
    /// resulting shapes are stored in the same order as `node.get_output()`.
    ///
    /// # Errors
    ///
    /// Returns [`IrError::OutputNodeNotFound`] for the first output name that
    /// has no entry in `value_shapes`.
    pub fn from(
        node: Cow<'model, NodeProto>,
        value_shapes: &HashMap<&'model str, Shape>,
    ) -> Result<OperatorDefinition<'model>, IrError> {
        let outputs = node.get_output();
        let mut output_shapes: Vec<Shape> = Vec::with_capacity(outputs.len());
        for output_name in outputs {
            // A single `get` replaces the former `contains_key` + index pair,
            // which hashed every output name twice.
            let shape = value_shapes
                .get(output_name.as_str())
                .ok_or_else(|| IrError::OutputNodeNotFound(output_name.to_string()))?;
            output_shapes.push(shape.clone());
        }
        Ok(OperatorDefinition {
            proto: node,
            output_shapes,
        })
    }
}
/// What a node in the IR graph stands for.
#[derive(Clone)]
pub enum NodeDefinition<'model> {
/// An operator (ONNX `NodeProto`) plus the shapes of its outputs.
Operator(Box<OperatorDefinition<'model>>),
/// A constant tensor; `Node::from_model` creates these from the graph's initializers.
Tensor(Box<Cow<'model, TensorProto>>),
/// A graph input declared by the model.
Input(&'model ValueInfoProto),
/// Synthetic root node collecting the requested graph outputs, listed in `names`.
Outputs { names: Vec<String> },
/// Stand-in for an input name that has no known producer
/// (presumably an omitted optional input).
Missing, }
/// Fallback definition that `Node::from_node` uses when an input name has no
/// entry in the output-name -> definition map.
static MISSING_OPTIONAL_INPUT: NodeDefinition<'static> = NodeDefinition::Missing;
/// An edge into a node: the node producing the value and which of its outputs it is.
#[derive(Clone)]
pub struct Input<'model> {
/// The producing node, shared with any other consumers of its outputs.
pub source_node: Arc<Node<'model>>,
/// Index into the producer's outputs. `Node::from_node` uses 0 for
/// non-operator sources (tensors, graph inputs, missing inputs).
pub output_index: usize,
}
/// A node of the IR graph: its definition plus its incoming edges.
///
/// Nodes are handed out as `Arc<Node>`, so one node can feed several consumers.
pub struct Node<'model> {
/// What this node represents.
pub definition: NodeDefinition<'model>,
/// Incoming edges. As built by `from_node`/`from_model` they follow the
/// order of the operator's input names (or of the requested output names
/// for the `Outputs` root).
pub inputs: Vec<Input<'model>>,
}
/// Errors raised while building the IR graph from an ONNX model.
#[derive(Debug, Error)]
pub enum IrError {
/// An output name could not be resolved: it has no recorded shape, no
/// operator produces it, or it is missing from its producer's outputs.
#[error("output node for output {0} not found")]
OutputNodeNotFound(String),
/// An input of the named node could not be matched to a producing node.
#[error("could not find node corresponding to input {input_name} of node {target_node_name}")]
InputNodeNotFound {
target_node_name: String,
input_name: String,
},
/// Wraps a [`DataTypeError`]; converted automatically via `#[from]`, so
/// `?` on a `Result<_, DataTypeError>` yields this variant.
#[error("issue with data types: {0}")]
Type(#[from] DataTypeError),
}
impl<'m> NodeDefinition<'m> {
    /// Returns the name this definition is known by.
    ///
    /// Operators are named after their first output (panicking if they have
    /// none), tensors and graph inputs after their proto name. The `Outputs`
    /// root answers with a single space and `Missing` with the empty string.
    /// The result always borrows from `self`.
    pub fn get_name(&self) -> Cow<'_, str> {
        let name: &str = match self {
            NodeDefinition::Operator(operator) => operator.proto.get_output()[0].as_str(),
            NodeDefinition::Tensor(tensor) => tensor.get_name(),
            NodeDefinition::Input(input) => input.get_name(),
            NodeDefinition::Outputs { .. } => " ",
            NodeDefinition::Missing => "",
        };
        Cow::Borrowed(name)
    }
}
impl NodeProto {
    /// Key identifying an operator node: an owned copy of its first output name.
    ///
    /// Panics (index out of bounds) if the node declares no outputs.
    fn unique_name(&self) -> String {
        let first_output: &String = &self.get_output()[0];
        first_output.to_owned()
    }
}
impl<'model> Node<'model> {
    /// Creates a node with the given definition and no inputs.
    pub fn new(variant: NodeDefinition<'model>) -> Node<'model> {
        Node {
            definition: variant,
            inputs: vec![],
        }
    }

    /// Returns this node's definition.
    pub fn definition(&self) -> &NodeDefinition<'model> {
        &self.definition
    }

    /// Translates the ONNX operator `node`, and recursively every operator it
    /// depends on, into a shared IR [`Node`].
    ///
    /// `nodes_by_unique_name` is a memo table. Operators are keyed by their first
    /// output name. Non-operator sources (initializers, graph inputs, missing
    /// optional inputs) are keyed by [`NodeDefinition::get_name`]. A value consumed
    /// by several nodes is therefore translated once and shared via `Arc`.
    ///
    /// Input names with no entry in `node_definitions_by_output` are wired to a
    /// shared `NodeDefinition::Missing` node.
    ///
    /// # Errors
    ///
    /// Returns [`IrError::OutputNodeNotFound`] in two cases:
    /// - an output of a translated operator has no shape in `value_shapes`;
    /// - an input name cannot be located among its producer's outputs.
    ///
    /// # Panics
    ///
    /// Panics if `node` (or any operator it depends on) declares no outputs,
    /// because its unique name is its first output.
    pub fn from_node<'a>(
        model: &'model ModelProto,
        node: Cow<'model, NodeProto>,
        value_shapes: &HashMap<&'model str, Shape>,
        node_definitions_by_output: &'a HashMap<String, NodeDefinition<'model>>,
        nodes_by_unique_name: &mut HashMap<String, Arc<Node<'model>>>,
    ) -> Result<Arc<Node<'model>>, IrError> {
        let node_name = node.unique_name();
        if let Some(existing) = nodes_by_unique_name.get(&node_name) {
            return Ok(existing.clone());
        }

        let inputs: Result<Vec<Input<'model>>, IrError> = node
            .get_input()
            .iter()
            .map(|input_name| {
                let source_node_definition = node_definitions_by_output
                    .get(input_name)
                    .unwrap_or(&MISSING_OPTIONAL_INPUT);
                Ok(match source_node_definition {
                    // Produced by another operator: translate it (or reuse the memoized
                    // node) and find which of its outputs this input refers to.
                    NodeDefinition::Operator(source_op) => Input {
                        source_node: Node::from_node(
                            model,
                            source_op.proto.clone(),
                            value_shapes,
                            node_definitions_by_output,
                            nodes_by_unique_name,
                        )?,
                        output_index: source_op
                            .proto
                            .get_output()
                            .iter()
                            .position(|s| s == input_name)
                            .ok_or_else(|| IrError::OutputNodeNotFound(input_name.to_string()))?,
                    },
                    // Leaf sources: reuse the node registered under this name, or create
                    // and register one. They are always referenced at output index 0.
                    NodeDefinition::Tensor(_)
                    | NodeDefinition::Input(_)
                    | NodeDefinition::Outputs { .. }
                    | NodeDefinition::Missing => {
                        let source_name = source_node_definition.get_name().into_owned();
                        let source_node = nodes_by_unique_name
                            .entry(source_name)
                            .or_insert_with(|| {
                                Arc::new(Node::new(source_node_definition.clone()))
                            })
                            .clone();
                        Input {
                            output_index: 0,
                            source_node,
                        }
                    }
                })
            })
            .collect();

        // Built before `inputs?` so that, as before, an error about this node's own
        // output shapes takes precedence over an error from one of its inputs.
        let definition = NodeDefinition::Operator(Box::new(OperatorDefinition::from(
            node,
            value_shapes,
        )?));
        let translated = Arc::new(Node {
            definition,
            inputs: inputs?,
        });
        nodes_by_unique_name.insert(node_name, translated.clone());
        Ok(translated)
    }

    /// Builds the IR graph for `model`, rooted at a synthetic
    /// [`NodeDefinition::Outputs`] node.
    ///
    /// `outputs` selects which values to compute. When it is `None`, every graph
    /// output declared by the model is used. The root's inputs follow the order
    /// of these output names.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// - a value shape cannot be read;
    /// - an operator output has no recorded shape;
    /// - a requested output is not produced by any operator in the graph;
    /// - translating a dependency fails (see [`Node::from_node`]).
    pub fn from_model(
        model: &'model ModelProto,
        outputs: Option<&[String]>,
    ) -> Result<Arc<Node<'model>>, IrError> {
        let graph = model.get_graph();

        // Shapes of intermediate values and of the (non-empty) graph outputs.
        let mut value_shapes: HashMap<&'model str, Shape> = HashMap::new();
        for vi in graph.get_value_info() {
            value_shapes.insert(vi.get_name(), vi.get_shape()?);
        }
        for vi in graph.get_output() {
            let output_name = vi.get_name();
            if !output_name.is_empty() {
                value_shapes.insert(output_name, vi.get_shape()?);
            }
        }

        // Map every value name to the definition that produces it.
        let mut node_definitions_by_output = HashMap::<String, NodeDefinition<'model>>::new();
        for node in graph.get_node() {
            let node_def = NodeDefinition::Operator(Box::new(OperatorDefinition::from(
                Cow::Borrowed(node),
                &value_shapes,
            )?));
            // Empty output names are skipped (presumably omitted optional outputs).
            for output in node.get_output() {
                if !output.is_empty() {
                    node_definitions_by_output.insert(output.to_string(), node_def.clone());
                }
            }
        }
        // Initializers are inserted after operators, so they win on a name clash.
        for initializer in graph.get_initializer() {
            log::debug!("Initializer {}", initializer.get_name());
            node_definitions_by_output.insert(
                initializer.get_name().to_string(),
                NodeDefinition::Tensor(Box::new(Cow::Borrowed(initializer))),
            );
        }

        let output_names: Vec<String> = match outputs {
            Some(outputs) => outputs.to_vec(),
            None => graph
                .get_output()
                .iter()
                .map(|x| x.get_name().to_string())
                .collect(),
        };

        // Graph inputs only fill gaps: names already defined keep their definition.
        for input in graph.get_input() {
            if !node_definitions_by_output.contains_key(input.get_name()) {
                log::debug!("Input {}", input.get_name());
                node_definitions_by_output
                    .insert(input.get_name().to_string(), NodeDefinition::Input(input));
            } else {
                log::info!(
                    "Skipping input definition {}: already defined",
                    input.get_name()
                );
            }
        }

        let mut nodes_by_name = HashMap::new();
        let output_nodes: Result<Vec<Input<'model>>, IrError> = output_names
            .iter()
            .map(|output_name| {
                // The first operator in the graph that lists this name among its outputs.
                let output_node = graph
                    .get_node()
                    .iter()
                    .find(|candidate| candidate.get_output().contains(output_name))
                    .ok_or_else(|| IrError::OutputNodeNotFound(output_name.clone()))?;
                let source_node = Node::from_node(
                    model,
                    Cow::Borrowed(output_node),
                    &value_shapes,
                    &node_definitions_by_output,
                    &mut nodes_by_name,
                )?;
                let output_index = output_node
                    .get_output()
                    .iter()
                    .position(|s| s == output_name)
                    .ok_or_else(|| IrError::OutputNodeNotFound(output_name.clone()))?;
                Ok(Input {
                    source_node,
                    output_index,
                })
            })
            .collect();

        Ok(Arc::new(Node {
            definition: NodeDefinition::Outputs {
                names: output_names,
            },
            inputs: output_nodes?,
        }))
    }
}
impl<'model> Debug for NodeDefinition<'model> {
    /// Short one-line description: operators show their node name and op type,
    /// tensors and inputs their name, and the two synthetic variants a fixed label.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            NodeDefinition::Operator(operator) => {
                let proto = &operator.proto;
                write!(f, "op: {} ({})", proto.get_name(), proto.get_op_type())
            }
            NodeDefinition::Tensor(tensor) => write!(f, "tensor {}", tensor.get_name()),
            NodeDefinition::Input(input) => write!(f, "input {}", input.get_name()),
            NodeDefinition::Outputs { .. } => f.write_str("outputs"),
            NodeDefinition::Missing => f.write_str("missing (optional)"),
        }
    }
}
/// Hashable handle to a shared node that compares by identity (the `Arc`
/// allocation it points to) rather than by value.
#[derive(Clone)]
pub struct NodeIdentifier<'model>(Arc<Node<'model>>);
impl<'model> Hash for NodeIdentifier<'model> {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
ptr::hash(Arc::as_ptr(&self.0), state)
}
}
impl<'model> PartialEq for NodeIdentifier<'model> {
    /// Two identifiers are equal exactly when they refer to the same node allocation.
    fn eq(&self, other: &Self) -> bool {
        let lhs = Arc::as_ptr(&self.0);
        let rhs = Arc::as_ptr(&other.0);
        std::ptr::eq(lhs, rhs)
    }
}
// Pointer identity is reflexive, so the equality above is a full equivalence.
impl<'model> Eq for NodeIdentifier<'model> {}
impl<'model> Node<'model> {
    /// Returns an identity-based handle to this shared node, usable as a
    /// `HashMap`/`HashSet` key. Only the `Arc` reference count is bumped.
    pub fn identifier(self: &Arc<Self>) -> NodeIdentifier<'model> {
        let shared = Arc::clone(self);
        NodeIdentifier(shared)
    }
}