use super::prelude::*;
use super::Result;
use model::OutletId;
use std::collections::HashMap;
use {Node, Tensor};
/// A member of a constant connected component.
///
/// Each variant carries an index: `Node` indexes the analyser's node list and
/// `Edge` indexes the analyser's edge list.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Element {
    /// Index of a node belonging to the component.
    Node(usize),
    /// Index of an edge belonging to the component.
    Edge(usize),
}
/// A connected group of constant nodes and edges, as produced by
/// `connected_components`.
#[derive(Debug)]
pub struct Component {
    /// Every node and edge visited while exploring the component, in visiting
    /// order.
    pub elements: Vec<Element>,

    /// Indices of the constant edges whose other end is a node that is not
    /// constant, i.e. the edges through which the component touches the rest
    /// of the graph.
    pub outputs: Vec<usize>,
}
pub fn connected_components(analyser: &Analyser) -> Result<Vec<Component>> {
let is_edge_const: Vec<bool> = analyser
.edges
.iter()
.map(|e| e.fact.value.is_concrete())
.collect();
let is_node_const: Vec<bool> = analyser
.next_edges
.iter()
.map(|next| next.len() > 0 && next.iter().all(|i| is_edge_const[*i]))
.collect();
let mut components = vec![];
let mut is_node_colored = vec![false; analyser.nodes.len()];
let mut is_edge_colored = vec![false; analyser.edges.len()];
let mut stack = vec![];
macro_rules! process_edges {
($from:ident, $other:expr, $component:expr, $node:expr) => {{
for &edge in &analyser.$from[$node] {
if !is_edge_const[edge] || is_edge_colored[edge] {
continue;
}
is_edge_colored[edge] = true;
$component.elements.push(Element::Edge(edge));
let target = $other(edge);
if target.is_none() {
continue;
}
if !is_node_const[target.unwrap()] {
$component.outputs.push(edge);
} else {
stack.push(target.unwrap());
}
}
}};
};
for (node, &is_const) in is_node_const.iter().enumerate() {
if is_const && !is_node_colored[node] {
let mut component = Component {
elements: vec![],
outputs: vec![],
};
stack.push(node);
while let Some(node) = stack.pop() {
if !is_node_const[node] || is_node_colored[node] {
continue;
}
is_node_colored[node] = true;
component.elements.push(Element::Node(node));
process_edges!(
prev_edges,
|e: usize| -> Option<usize> { analyser.edges[e].from.map(|n| n.node) },
component,
node
);
process_edges!(
next_edges,
|e: usize| analyser.edges[e].to_node,
component,
node
);
}
components.push(component);
}
}
Ok(components)
}
/// Creates a node with no inputs whose operation is a `Const` built from
/// `tensor`.
///
/// `id` and `name` are stored on the node as given; `op_name` is always
/// `"Const"`.
fn build_const_node(id: usize, name: String, tensor: Tensor) -> Node {
    let op_name = String::from("Const");
    let op = Box::new(::ops::konst::Const::for_tensor(tensor));

    Node {
        id,
        name,
        op_name,
        inputs: Vec::new(),
        op,
    }
}
/// Feeds every output edge of a constant component from a generated `Const`
/// node instead of its original source.
///
/// For each output edge reported by `connected_components`, the edge's
/// concrete tensor is wrapped in a `Const` node named `generated_<id>`. The
/// edge is then removed from the successors of its original source node, the
/// matching input of its target node is rewritten, and the edge is attached
/// to outlet 0 of the `Const` node. Tensors for which `take_i32s` yields a
/// value share a single node per distinct value. `reset_plan` is called once
/// all edges have been rewritten.
///
/// # Errors
///
/// Returns any error produced by `connected_components` or `reset_plan`.
///
/// # Panics
///
/// Panics if an output edge has no concrete value, no source, or no target
/// node, or if the edge is missing from the adjacency data of its endpoints.
pub fn propagate_constants(analyser: &mut Analyser) -> Result<()> {
    let components: Vec<Component> = connected_components(analyser)?;
    info!("Detected {:?} connected components.", components.len());

    // Maps an integer tensor value to the id of the `Const` node generated
    // for it, so identical integer constants share one node.
    let mut const_int_nodes = HashMap::new();

    for component in components {
        for i in component.outputs {
            let tensor = analyser.edges[i]
                .fact
                .value
                .concretize()
                .expect("output edges of a constant component hold a concrete value");

            // A freshly created node always receives this id; a reused one
            // has a smaller id.
            let nodes_before = analyser.nodes.len();

            let const_node_id: usize = if let Some(ints) = tensor.clone().take_i32s() {
                *const_int_nodes.entry(ints.clone()).or_insert_with(|| {
                    let node_id = analyser.nodes.len();
                    let node_name = format!("generated_{}", node_id);
                    let node = build_const_node(node_id, node_name, ints.into());
                    analyser.nodes.push(node);
                    node_id
                })
            } else {
                let node_id = analyser.nodes.len();
                let node_name = format!("generated_{}", node_id);
                let node = build_const_node(node_id, node_name, tensor);
                analyser.nodes.push(node);
                node_id
            };
            let is_new_node = const_node_id == nodes_before;

            let edge = &mut analyser.edges[i];
            let old_node_id = edge
                .from
                .expect("output edges of a constant component have a source node")
                .node;

            // Detach the edge from its original source node.
            {
                let successors = &mut analyser.next_edges[old_node_id];
                let position = successors
                    .iter()
                    .position(|&e| e == edge.id)
                    .expect("edge is listed among the successors of its source node");
                successors.remove(position);
            }

            // Point the target node's input at the constant node.
            // NOTE(review): the input is matched on the source node id only;
            // if the target consumes several outlets of the same node, the
            // first matching input is the one rewritten. Confirm this is
            // intended.
            {
                let target = edge
                    .to_node
                    .expect("output edges of a constant component have a target node");
                let predecessors = &mut analyser.nodes[target].inputs;
                let position = predecessors
                    .iter()
                    .position(|outlet| outlet.node == old_node_id)
                    .expect("target node lists the original source among its inputs");
                predecessors[position] = OutletId::new(const_node_id, 0);
            }

            edge.from = Some(OutletId::new(const_node_id, 0));

            // `prev_edges` and `next_edges` are indexed by node id, so they
            // may only grow when a node was actually added. A reused constant
            // node already has its entries and just gains one more successor.
            if is_new_node {
                analyser.prev_edges.push(vec![]);
                analyser.next_edges.push(vec![edge.id]);
            } else {
                analyser.next_edges[const_node_id].push(edge.id);
            }
        }
    }

    analyser.reset_plan()?;
    Ok(())
}