use std::{
cell::RefCell,
collections::{BTreeMap, BTreeSet, VecDeque},
};
use portgraph::{
portgraph::PortOperation, NodeIndex, PortGraph, PortIndex, PortOffset, SecondaryMap,
};
/// Untangles the threads recorded in `trace` by splitting nodes of `graph`.
///
/// `trace` associates every port with a pair `(threads, flag)`, where `threads`
/// lists the thread indices recorded on that port. Input ports may carry at most
/// one thread index.
///
/// Nodes are processed starting from the successors of `root`; a node is only
/// processed once every predecessor that carries a trace has been processed. The
/// input ports of a node are grouped by the thread index they carry (ports without
/// one form their own group). When there is more than one group, the node is split
/// with [`split_node`] into one copy per group, and the traces of each copy's output
/// ports, and of the input ports they link to, are restricted to that copy's thread
/// index (or cleared for the group without one).
///
/// Output ports whose `flag` is `false` are kept on every copy. Output ports whose
/// `flag` is `true` are only kept on the copies whose thread index appears in the
/// port's trace.
///
/// Callbacks:
/// - `clone_state(old, new, graph)` is called every time node `old` is duplicated
///   into `new`.
/// - `rekey(old, new)` is called for every port moved to `Some(new)` or removed
///   (`None`) while splitting.
///
/// Returns the nodes still tracked as carrying a thread whose outgoing links all
/// point to ports with an empty trace. If no port carries a trace, returns `{root}`
/// without modifying `graph`.
///
/// # Panics
///
/// Panics if an input port of a processed node carries more than one thread index,
/// if a port of a node being split is disconnected, or on invalid port links.
pub fn untangle_threads<F, G, Map>(
    graph: &mut PortGraph,
    mut trace: Map,
    root: NodeIndex,
    mut clone_state: G,
    mut rekey: F,
) -> BTreeSet<NodeIndex>
where
    F: FnMut(PortIndex, Option<PortIndex>),
    G: FnMut(NodeIndex, NodeIndex, &PortGraph),
    Map: SecondaryMap<PortIndex, (Vec<usize>, bool)>,
{
    // Nodes with at least one port carrying a thread index.
    let mut all_nodes = graph
        .nodes_iter()
        .filter(|&n| graph.all_ports(n).any(|p| !trace.get(p).0.is_empty()))
        .collect::<BTreeSet<_>>();
    if all_nodes.is_empty() {
        return [root].into();
    }
    let mut curr_nodes: VecDeque<_> = graph
        .output_links(root)
        .flatten()
        .map(|p| graph.port_node(p).expect("Invalid port"))
        .collect();
    let mut visited: BTreeSet<_> = [root].into();
    // Scratch map reused across iterations: while a node is split, holds the traces
    // of the input ports linked to its outputs, keyed by the node's output port.
    let mut trace_next_in = BTreeMap::new();
    while let Some(node) = curr_nodes.pop_front() {
        if visited.contains(&node) {
            continue;
        } else if graph
            .input_links(node)
            .flatten()
            .map(|p| graph.port_node(p).expect("Invalid port"))
            .any(|n| all_nodes.contains(&n) && !visited.contains(&n))
        {
            // A traced predecessor has not been processed yet: retry later.
            curr_nodes.push_back(node);
            continue;
        } else {
            visited.insert(node);
        }
        // Group the input ports by the (at most one) thread index they carry.
        let mut ins = BTreeMap::new();
        for p in graph.inputs(node) {
            let vec = trace.get(p).0.as_slice();
            assert!(vec.len() <= 1);
            let k = vec.first().copied();
            ins.entry(k).or_insert_with(Vec::new).push(p);
        }
        let keys: Vec<_> = ins.keys().copied().collect();
        let ins: Vec<_> = ins.into_values().collect();
        // For each group, the output ports its copy of `node` keeps.
        let outs: Vec<_> = keys
            .iter()
            .map(|&l| {
                graph
                    .outputs(node)
                    .filter(|&p| {
                        let vec = trace.get(p).0.as_slice();
                        let is_new = trace.get(p).1;
                        if !is_new {
                            return true;
                        }
                        let Some(l) = l else { return false };
                        vec.contains(&l)
                    })
                    .collect::<Vec<_>>()
            })
            .collect();
        let mut next_nodes = graph
            .output_links(node)
            .flatten()
            .map(|p| graph.port_node(p).expect("Invalid port"))
            .filter(|&n| all_nodes.contains(&n))
            .collect();
        if ins.len() > 1 {
            all_nodes.remove(&node);
            // Splitting relinks the outputs of `node`, so take the traces of the
            // linked input ports out now and restore them on the new links below.
            trace_next_in.clear();
            for out_port in graph.outputs(node) {
                let in_port = graph.port_link(out_port).expect("Disconnected port");
                trace_next_in.insert(out_port, trace.take(in_port));
            }
            // Both callbacks handed to `split_node` need mutable access to the traces.
            let trace_mut = RefCell::new(&mut trace);
            let trace_curr_mut = RefCell::new(&mut trace_next_in);
            // Copies the traces of `old`'s ports onto the matching ports of `new`,
            // marks `new` as visited if `old` is, then forwards to the caller's
            // `clone_state` (the binding being defined is not yet in scope here).
            let clone_state = |old, new, graph: &PortGraph| {
                let mut trace = trace_mut.borrow_mut();
                let mut trace_curr = trace_curr_mut.borrow_mut();
                if visited.contains(&old) {
                    visited.insert(new);
                }
                for old_port in graph.all_ports(old) {
                    let offset = graph.port_offset(old_port).expect("invalid port");
                    let new_port = graph.port_index(new, offset).expect("invalid offset");
                    if trace.get(old_port) != &Default::default() {
                        let old_val = trace.get(old_port).clone();
                        trace.set(new_port, old_val);
                    }
                    if let Some(val) = trace_curr.get(&old_port).cloned() {
                        trace_curr.insert(new_port, val);
                    }
                }
                clone_state(old, new, graph)
            };
            // Keeps both trace maps in sync with port moves/removals, then forwards
            // to the caller's `rekey`.
            let new_nodes = split_node(graph, node, &ins, &outs, clone_state, |old, new| {
                let mut trace = trace_mut.borrow_mut();
                let mut trace_curr = trace_curr_mut.borrow_mut();
                trace.rekey(old, new);
                if let Some(val) = trace_curr.remove(&old) {
                    if let Some(new) = new {
                        trace_curr.insert(new, val);
                    }
                }
                rekey(old, new)
            });
            // Restore the stashed traces on the input ports now linked to the copies'
            // outputs. Default values are only written where `trace` already has
            // capacity (presumably to avoid growing the map needlessly).
            for out_port in new_nodes.iter().flat_map(|&n| graph.outputs(n)) {
                let in_port = graph.port_link(out_port).expect("Disconnected port");
                if trace_next_in[&out_port] != Default::default()
                    || in_port.index() < trace.capacity()
                {
                    trace.set(in_port, trace_next_in.remove(&out_port).unwrap());
                }
            }
            // Restrict the traces on each copy's outputs, and on the inputs they link
            // to, to the copy's thread index (cleared for the group without one).
            // NOTE(review): the matching input entry is picked at the same position
            // as in the output trace, i.e. both traces are assumed index-aligned.
            for (k, n) in keys.iter().zip(new_nodes) {
                if k.is_some() {
                    all_nodes.insert(n);
                }
                for out_port in graph.outputs(n) {
                    let in_port = graph.port_link(out_port).expect("Disconnected port");
                    let (out_trace, out_flag) = trace.get(out_port);
                    let pos = k.and_then(|k| out_trace.iter().position(|&x| x == k));
                    let new_out = (
                        pos.map(|pos| vec![out_trace[pos]]).unwrap_or_default(),
                        *out_flag,
                    );
                    let (in_trace, in_flag) = trace.get(in_port);
                    let new_in = (
                        pos.map(|pos| vec![in_trace[pos]]).unwrap_or_default(),
                        *in_flag,
                    );
                    trace.set(out_port, new_out);
                    trace.set(in_port, new_in);
                }
            }
        }
        curr_nodes.append(&mut next_nodes);
    }
    all_nodes
        .iter()
        .copied()
        .filter(|&n| {
            graph
                .output_links(n)
                .flatten()
                .all(|p| trace.get(p).0.is_empty())
        })
        .collect()
}
/// Splits `old_node` into one node per partition of its ports.
///
/// `ins[i]` and `outs[i]` are the input and output ports of `old_node` assigned to
/// the `i`-th partition; both slices must have the same length. Partition 0 stays
/// on `old_node`. Every other partition gets a freshly added node with the same
/// number of ports as `old_node`, reported via `clone_state(old_node, new, graph)`.
///
/// - The links of `ins[i]` are moved, in order, to the first `ins[i].len()` input
///   ports of the `i`-th node.
/// - An output port may be listed in several partitions. Its first occurrence (in
///   partition order) is linked to the port's original target. Every further
///   occurrence is linked to a new input port appended to the target node.
/// - Linked output ports are then moved down into free lower offsets, and each node
///   is resized to `ins[i].len()` inputs and `outs[i].len()` outputs.
///
/// `rekey(old, new)` is called for every port that is moved (`Some(new)`) or
/// removed (`None`) along the way.
///
/// Returns the nodes in partition order; the first one is `old_node`.
///
/// # Panics
///
/// Panics if `ins` and `outs` have different lengths, or if a port in `ins` or any
/// output port of `old_node` is disconnected.
fn split_node<F, G>(
    graph: &mut PortGraph,
    old_node: NodeIndex,
    ins: &[Vec<PortIndex>],
    outs: &[Vec<PortIndex>],
    mut clone_state: G,
    mut rekey: F,
) -> Vec<NodeIndex>
where
    F: FnMut(PortIndex, Option<PortIndex>),
    G: FnMut(NodeIndex, NodeIndex, &PortGraph),
{
    let n_partitions = ins.len();
    let num_in = graph.num_inputs(old_node);
    let num_out = graph.num_outputs(old_node);
    assert_eq!(n_partitions, outs.len());
    let mut nodes = Vec::with_capacity(n_partitions);
    nodes.push(old_node);
    for i in 1..n_partitions {
        nodes.push(graph.add_node(num_in, num_out));
        clone_state(old_node, nodes[i], graph);
    }
    // Relink input ports. Iterate in reverse so that `old_node` (partition 0), whose
    // inputs are reused in place, is relinked last, after every other partition's
    // input ports have already been unlinked.
    for (&node, in_ports) in nodes.iter().zip(ins).rev() {
        let out_ports = in_ports
            .iter()
            .map(|&p| graph.unlink_port(p).expect("Disconnected port"))
            .collect::<Vec<_>>()
            .into_iter();
        let new_in_ports = (0..in_ports.len())
            .map(|i| {
                graph
                    .port_index(node, PortOffset::new_incoming(i))
                    .expect("in_ports.len() <= num_in")
            })
            .collect::<Vec<_>>();
        for ((new_out, new_in), &old_in) in out_ports.zip(new_in_ports).zip(in_ports) {
            graph.link_ports(new_out, new_in).unwrap();
            if old_in != new_in {
                rekey(old_in, new_in.into());
            }
        }
    }
    let mut rekey = |old, new: PortOperation| rekey(old, new.new_index());
    // Count, per target node, how many extra input ports are needed: every repeated
    // occurrence of an output port (after the first) needs its own link.
    let mut cnts = BTreeMap::<_, usize>::new();
    let mut outport_seen = BTreeSet::new();
    for &out_port in outs.iter().flatten() {
        if outport_seen.insert(out_port) {
            continue;
        }
        let next_port = graph.port_link(out_port).expect("Disconnected port");
        *cnts.entry(graph.port_node(next_port).unwrap()).or_default() += 1;
    }
    // Next free offset among the appended input ports of each target node.
    let mut new_port_offset: BTreeMap<_, _> = cnts
        .keys()
        .map(|&node| (node, graph.num_inputs(node)))
        .collect();
    for (&node, &add_port) in cnts.iter() {
        graph.set_num_ports(
            node,
            graph.num_inputs(node) + add_port,
            graph.num_outputs(node),
            &mut rekey,
        );
    }
    // Original link target of every output of `old_node` (read after the resizing
    // above, which may have moved the targets' ports).
    let links: BTreeMap<_, _> = graph
        .outputs(old_node)
        .map(|p| (p, graph.unlink_port(p).expect("Disconnected port")))
        .collect();
    for (&node, out_ports) in nodes.iter().zip(outs) {
        let in_ports = out_ports
            .iter()
            .map(|&p| {
                let in_port = links[&p];
                if graph.port_link(in_port).is_none() {
                    in_port
                } else {
                    let in_node = graph.port_node(in_port).expect("Invalid port");
                    let offset = new_port_offset.get_mut(&in_node).unwrap();
                    let new_in_port = graph
                        .port_index(in_node, PortOffset::new_incoming(*offset))
                        .expect("preallocated above");
                    *offset += 1;
                    new_in_port
                }
            })
            .collect::<Vec<_>>()
            .into_iter();
        let new_out_ports = out_ports
            .iter()
            .map(|&p| {
                let offset = graph.port_offset(p).expect("Invalid port");
                graph.port_index(node, offset).expect("same sig")
            })
            .collect::<Vec<_>>()
            .into_iter();
        for (new_out, new_in) in new_out_ports.zip(in_ports) {
            graph.link_ports(new_out, new_in).unwrap();
        }
    }
    // Move linked output ports down into free lower offsets, so that the resize
    // below only drops unlinked ports.
    for &node in nodes.iter() {
        let linked_ports = graph
            .outputs(node)
            .filter(|&p| graph.port_link(p).is_some())
            .collect::<Vec<_>>()
            .into_iter();
        for old_out in linked_ports {
            let Some(new_out) = graph
                .outputs(node)
                .find(|&p| graph.port_link(p).is_none())
            else {
                break;
            };
            if new_out < old_out {
                let in_port = graph.unlink_port(old_out).expect("is linked");
                graph.link_ports(new_out, in_port).expect("is free");
                rekey(old_out, PortOperation::Moved { new_index: new_out });
            }
        }
    }
    for ((&node, ins), outs) in nodes.iter().zip(ins).zip(outs) {
        graph.set_num_ports(node, ins.len(), outs.len(), &mut rekey);
    }
    nodes
}
#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;

    use portgraph::{PortGraph, PortOffset, UnmanagedDenseMap};

    use super::untangle_threads;

    /// Builds a small graph carrying traced threads, untangles it starting from its
    /// only source node and checks the resulting node count.
    #[test]
    fn test_cover() {
        let mut g = PortGraph::new();
        let nodes = vec![
            g.add_node(0, 2),
            g.add_node(2, 1),
            g.add_node(1, 2),
            g.add_node(2, 1),
            g.add_node(1, 1),
            g.add_node(1, 0),
        ];
        let edge = |g: &PortGraph, (n1, p1), (n2, p2)| {
            (
                g.port_index(nodes[n1], PortOffset::new_outgoing(p1))
                    .unwrap(),
                g.port_index(nodes[n2], PortOffset::new_incoming(p2))
                    .unwrap(),
            )
        };
        let threads = [
            edge(&g, (0, 0), (1, 0)),
            edge(&g, (0, 1), (2, 0)),
            edge(&g, (1, 0), (3, 0)),
            edge(&g, (2, 0), (1, 1)),
            edge(&g, (3, 0), (5, 0)),
        ];
        let other_edges = [edge(&g, (2, 1), (4, 0)), edge(&g, (4, 0), (3, 1))];
        for (out_p, in_p) in threads.iter().copied().chain(other_edges) {
            g.link_ports(out_p, in_p).unwrap();
        }
        let mut new_nodes = BTreeMap::new();
        let mut trace: UnmanagedDenseMap<_, _> = Default::default();
        let thread_inds = [
            (vec![0], vec![1]),
            (vec![0], vec![1]),
            (vec![1, 2], vec![2, 3]),
            (vec![1], vec![2]),
            (vec![2, 3], vec![3, 4]),
        ];
        for (&(out_port, in_port), (out_ind, in_ind)) in threads.iter().zip(thread_inds) {
            trace[in_port] = (in_ind, false);
            trace[out_port] = (out_ind, false);
        }
        untangle_threads(
            &mut g,
            trace,
            // The root must be a node of `g`: its only source node.
            nodes[0],
            |old_n, new_n, _| {
                new_nodes.insert(old_n, new_n);
            },
            |_, _| {},
        );
        assert_eq!(g.node_count(), 11);
    }
}