use std::collections::{BTreeMap, BTreeSet};
use itertools::Itertools;
use proc_macro2::Span;
use slotmap::{SecondaryMap, SparseSecondaryMap};
use super::meta_graph::DfirGraph;
use super::ops::{DelayType, FloType};
use super::{
Color, GraphEdgeId, GraphNode, GraphNodeId, GraphSubgraphId, HandoffKind, graph_algorithms,
};
use crate::diagnostic::{Diagnostic, Level};
use crate::union_find::UnionFind;
/// Node/edge dependencies that must not be collapsed into a single subgraph.
///
/// Two kinds of crossers are stored:
/// * `edge_barrier_crossers`: graph edges tagged with a `DelayType`.
/// * `singleton_barrier_crossers`: `(src, dst)` node pairs, which
///   [`BarrierCrossers::iter_node_pairs`] reports as `DelayType::Stratum`.
struct BarrierCrossers {
    /// Delay type recorded for each barrier-crossing edge.
    pub edge_barrier_crossers: SecondaryMap<GraphEdgeId, DelayType>,
    /// Barrier-crossing `(src, dst)` node pairs that are not tied to an edge.
    pub singleton_barrier_crossers: Vec<(GraphNodeId, GraphNodeId)>,
}

impl BarrierCrossers {
    /// Yields every crosser as `((src, dst), delay_type)`.
    ///
    /// Edge crossers come first, with their endpoints looked up in
    /// `partitioned_graph`. They are followed by the singleton pairs, each
    /// tagged `DelayType::Stratum`.
    fn iter_node_pairs<'a>(
        &'a self,
        partitioned_graph: &'a DfirGraph,
    ) -> impl 'a + Iterator<Item = ((GraphNodeId, GraphNodeId), DelayType)> {
        let from_edges = self
            .edge_barrier_crossers
            .iter()
            .map(move |(edge_id, delay_type)| (partitioned_graph.edge(edge_id), *delay_type));
        let from_singletons = self
            .singleton_barrier_crossers
            .iter()
            .copied()
            .map(|node_pair| (node_pair, DelayType::Stratum));
        from_edges.chain(from_singletons)
    }

    /// Re-keys the delay type stored for `old_edge_id` (if any) onto
    /// `new_edge_id`. Does nothing when `old_edge_id` is not a crosser.
    fn replace_edge(&mut self, old_edge_id: GraphEdgeId, new_edge_id: GraphEdgeId) {
        let Some(delay_type) = self.edge_barrier_crossers.remove(old_edge_id) else {
            return;
        };
        self.edge_barrier_crossers.insert(new_edge_id, delay_type);
    }
}
fn find_barrier_crossers(partitioned_graph: &DfirGraph) -> BarrierCrossers {
let edge_barrier_crossers = partitioned_graph
.edges()
.filter(|&(_, (_src, dst))| {
partitioned_graph.node_loop(dst).is_none()
})
.filter_map(|(edge_id, (_src, dst))| {
let (_src_port, dst_port) = partitioned_graph.edge_ports(edge_id);
let op_constraints = partitioned_graph.node_op_inst(dst)?.op_constraints;
let input_barrier = (op_constraints.input_delaytype_fn)(dst_port)?;
Some((edge_id, input_barrier))
})
.collect();
let mut singleton_barrier_crossers: Vec<(GraphNodeId, GraphNodeId)> = partitioned_graph
.node_ids()
.flat_map(|dst| {
partitioned_graph
.node_singleton_references(dst)
.iter()
.filter_map(|r| r.node_id)
.map(move |src_ref| (src_ref, dst))
})
.collect();
let refs_by_target = partitioned_graph.node_singleton_reference_groups();
for (_singleton, groups) in refs_by_target {
for (group_a, group_b) in groups.values().tuple_windows() {
for &(node_a, _, _) in group_a {
for &(node_b, _, _) in group_b {
assert_ne!(
node_a, node_b,
"encounted conflicted or cyclical singleton references\n{:?}\n{:?}",
group_a, group_b,
);
singleton_barrier_crossers.push((node_a, node_b));
}
}
}
}
BarrierCrossers {
edge_barrier_crossers,
singleton_barrier_crossers,
}
}
/// Greedily merges adjacent nodes into subgraphs using a union-find.
///
/// Nodes with a known color (`node_color`) are pre-colored. Every edge starts
/// out in the returned handoff-edge set. The function then sweeps all edges
/// repeatedly and unions `src` with `dst` whenever allowed, removing that edge
/// from the handoff set. Sweeps continue until a full pass makes no union,
/// because an uncolored node may only get a color once a neighbor has been
/// merged (see `can_connect_colorize`).
///
/// An edge is left as a handoff when any of these hold:
/// * `src` and `dst` are already in the same set;
/// * one set holds an end of a barrier crosser and the other set holds its
///   other end (either direction);
/// * `src` and `dst` are in different loop contexts;
/// * `dst` is a `FloType::NextIteration` operator;
/// * `can_connect_colorize` rejects the pair.
///
/// Returns the union-find over nodes and the set of edges that still need a
/// handoff.
///
/// # Panics
/// Panics if a merged edge is missing from the handoff set.
fn find_subgraph_unionfind(
    partitioned_graph: &DfirGraph,
    barrier_crossers: &BarrierCrossers,
) -> (UnionFind<GraphNodeId>, BTreeSet<GraphEdgeId>) {
    // Pre-color the nodes whose color is already determined by the graph.
    let mut node_color = partitioned_graph
        .node_ids()
        .filter_map(|node_id| {
            let op_color = partitioned_graph.node_color(node_id)?;
            Some((node_id, op_color))
        })
        .collect::<SparseSecondaryMap<_, _>>();

    let mut subgraph_unionfind: UnionFind<GraphNodeId> =
        UnionFind::with_capacity(partitioned_graph.nodes().len());

    // Starts with every edge; an edge is removed once its endpoints are merged.
    let mut handoff_edges: BTreeSet<GraphEdgeId> = partitioned_graph.edge_ids().collect();

    // The graph is only borrowed immutably, so neither its edge list nor the
    // barrier crossers' endpoints change between sweeps. Gather both once
    // instead of re-collecting the edges on every sweep and re-resolving every
    // barrier edge's endpoints for every edge visited.
    let edges: Vec<_> = partitioned_graph.edges().collect();
    let barrier_pairs: Vec<(GraphNodeId, GraphNodeId)> = barrier_crossers
        .iter_node_pairs(partitioned_graph)
        .map(|(src_dst, _delay_type)| src_dst)
        .collect();

    let mut progress = true;
    while progress {
        progress = false;
        for &(edge_id, (src, dst)) in &edges {
            // Already merged (this also covers self-loops within one set).
            // The edge itself stays in `handoff_edges`.
            if subgraph_unionfind.same_set(src, dst) {
                continue;
            }

            // Never merge two sets that would put both ends of a barrier
            // crosser into the same subgraph.
            if barrier_pairs.iter().any(|&(x_src, x_dst)| {
                (subgraph_unionfind.same_set(x_src, src)
                    && subgraph_unionfind.same_set(x_dst, dst))
                    || (subgraph_unionfind.same_set(x_src, dst)
                        && subgraph_unionfind.same_set(x_dst, src))
            }) {
                continue;
            }

            // Do not connect across loop contexts.
            if partitioned_graph.node_loop(src) != partitioned_graph.node_loop(dst) {
                continue;
            }

            // Do not connect into a `NextIteration` operator.
            if partitioned_graph.node_op_inst(dst).is_some_and(|op_inst| {
                Some(FloType::NextIteration) == op_inst.op_constraints.flo_type
            }) {
                continue;
            }

            if can_connect_colorize(&mut node_color, src, dst) {
                subgraph_unionfind.union(src, dst);
                assert!(handoff_edges.remove(&edge_id));
                progress = true;
            }
        }
    }

    (subgraph_unionfind, handoff_edges)
}
fn make_subgraph_collect(
partitioned_graph: &DfirGraph,
mut subgraph_unionfind: UnionFind<GraphNodeId>,
) -> SecondaryMap<GraphNodeId, Vec<GraphNodeId>> {
let topo_sort = graph_algorithms::topo_sort(
partitioned_graph
.nodes()
.filter(|&(_, node)| !matches!(node, GraphNode::Handoff { .. }))
.map(|(node_id, _)| node_id),
|v| {
partitioned_graph
.node_predecessor_nodes(v)
.filter(|&pred_id| {
let pred = partitioned_graph.node(pred_id);
!matches!(pred, GraphNode::Handoff { .. })
})
},
)
.expect("Subgraphs are in-out trees.");
let mut grouped_nodes: SecondaryMap<GraphNodeId, Vec<GraphNodeId>> = Default::default();
for node_id in topo_sort {
let repr_node = subgraph_unionfind.find(node_id);
if !grouped_nodes.contains_key(repr_node) {
grouped_nodes.insert(repr_node, Default::default());
}
grouped_nodes[repr_node].push(node_id);
}
grouped_nodes
}
fn make_subgraphs(partitioned_graph: &mut DfirGraph, barrier_crossers: &mut BarrierCrossers) {
let (subgraph_unionfind, handoff_edges) =
find_subgraph_unionfind(partitioned_graph, barrier_crossers);
for edge_id in handoff_edges {
let (src_id, dst_id) = partitioned_graph.edge(edge_id);
let src_node = partitioned_graph.node(src_id);
let dst_node = partitioned_graph.node(dst_id);
if matches!(src_node, GraphNode::Handoff { .. })
|| matches!(dst_node, GraphNode::Handoff { .. })
{
continue;
}
let hoff = GraphNode::Handoff {
kind: HandoffKind::Vec,
src_span: src_node.span(),
dst_span: dst_node.span(),
};
let (_node_id, out_edge_id) = partitioned_graph.insert_intermediate_node(edge_id, hoff);
barrier_crossers.replace_edge(edge_id, out_edge_id);
}
let grouped_nodes = make_subgraph_collect(partitioned_graph, subgraph_unionfind);
for (_repr_node, member_nodes) in grouped_nodes {
partitioned_graph.insert_subgraph(member_nodes).unwrap();
}
}
/// Decides whether the edge `src -> dst` may stay inside one subgraph. If one
/// endpoint has no color, that color is inferred from the other endpoint.
///
/// * Neither colored: returns `false`.
/// * Only `dst` colored: `src` becomes `Pull` when `dst` is `Pull`/`Comp`, and
///   `Push` otherwise; returns `true`.
/// * Only `src` colored: `dst` becomes `Pull` when `src` is `Pull`/`Hoff`, and
///   `Push` otherwise; returns `true`.
/// * Both colored: `true` only for `Pull -> Pull|Comp|Push`, `Comp -> Push`,
///   and `Push -> Push`. Any pair involving `Hoff` is rejected.
///
/// `node_color` is written only in the cases where one endpoint was
/// uncolored.
fn can_connect_colorize(
    node_color: &mut SparseSecondaryMap<GraphNodeId, Color>,
    src: GraphNodeId,
    dst: GraphNodeId,
) -> bool {
    match (node_color.get(src), node_color.get(dst)) {
        (None, None) => false,
        (None, Some(dst_color)) => {
            let inferred_src = match dst_color {
                Color::Pull | Color::Comp => Color::Pull,
                Color::Push | Color::Hoff => Color::Push,
            };
            node_color.insert(src, inferred_src);
            true
        }
        (Some(src_color), None) => {
            let inferred_dst = match src_color {
                Color::Pull | Color::Hoff => Color::Pull,
                Color::Comp | Color::Push => Color::Push,
            };
            node_color.insert(dst, inferred_dst);
            true
        }
        (Some(src_color), Some(dst_color)) => match (src_color, dst_color) {
            (Color::Hoff, _) | (_, Color::Hoff) => false,
            (Color::Pull, Color::Pull | Color::Comp | Color::Push) => true,
            (Color::Comp, Color::Push) => true,
            (Color::Comp, Color::Pull | Color::Comp) => false,
            (Color::Push, Color::Push) => true,
            (Color::Push, Color::Pull | Color::Comp) => false,
        },
    }
}
/// Verifies that subgraphs can be ordered within a tick, then records tick
/// delays on the relevant handoffs.
///
/// A subgraph-level predecessor map (`succ_sg` after `pred_sg`) is built from
/// two sources:
/// * Each handoff that has a successor. The exception is a handoff whose
///   outgoing edge is a `DelayType::Tick` or `DelayType::TickLazy` barrier
///   crosser: those edges are only remembered and are excluded from the map.
/// * Each singleton barrier-crosser pair whose two sides land in different
///   subgraphs.
///
/// The map is then topologically sorted. If that succeeds, every remembered
/// tick edge's source handoff gets its delay type set via
/// `set_handoff_delay_type`.
///
/// # Errors
/// Returns an error [`Diagnostic`] if the predecessor map contains a cycle.
/// The diagnostic is spanned at the first node of the first subgraph in the
/// reported cycle, falling back to `Span::call_site()`.
///
/// # Panics
/// Panics if any of these hold:
/// * a handoff has more than one successor;
/// * a non-tick handoff does not have exactly one predecessor;
/// * a node expected to be in a subgraph is not;
/// * a singleton pair has identical ends;
/// * a tick edge's source is not a `HandoffKind::Vec` handoff.
fn order_subgraphs(
    partitioned_graph: &mut DfirGraph,
    barrier_crossers: &BarrierCrossers,
) -> Result<(), Diagnostic> {
    // Subgraph -> subgraphs that must come before it within a tick.
    let mut sg_preds: BTreeMap<GraphSubgraphId, Vec<GraphSubgraphId>> = Default::default();
    // Handoff out-edges carrying a tick delay; applied only after the cycle check.
    let mut tick_edges: Vec<(GraphEdgeId, DelayType)> = Vec::new();
    for (hoff_id, hoff) in partitioned_graph.nodes() {
        if !matches!(hoff, GraphNode::Handoff { .. }) {
            continue;
        }
        // A handoff with no consumer adds no ordering constraint.
        if partitioned_graph.node_degree_out(hoff_id) == 0 {
            continue;
        }
        assert_eq!(1, partitioned_graph.node_degree_out(hoff_id));
        let (succ_edge, succ) = partitioned_graph.node_successors(hoff_id).next().unwrap();
        let succ_edge_delaytype = barrier_crossers
            .edge_barrier_crossers
            .get(succ_edge)
            .copied();
        // Tick-delayed edges are left out of `sg_preds`, so they do not count
        // toward within-tick cycles.
        if let Some(delay_type @ (DelayType::Tick | DelayType::TickLazy)) = succ_edge_delaytype {
            tick_edges.push((succ_edge, delay_type));
            continue;
        }
        assert_eq!(1, partitioned_graph.node_degree_in(hoff_id));
        let (_edge_id, pred) = partitioned_graph.node_predecessors(hoff_id).next().unwrap();
        let pred_sg = partitioned_graph
            .node_subgraph(pred)
            .expect("Handoff pred not in subgraph, may be a doubled/adjacent handoff");
        let succ_sg = partitioned_graph
            .node_subgraph(succ)
            .expect("Handoff succ not in subgraph, may be a doubled/adjacent handoff");
        sg_preds.entry(succ_sg).or_default().push(pred_sg);
    }
    for &(pred, succ) in barrier_crossers.singleton_barrier_crossers.iter() {
        assert_ne!(pred, succ);
        // If `pred` is not in any subgraph (presumably a handoff, per the
        // `expect` message below), use the subgraph of its first predecessor.
        let pred_sg = if let Some(sg) = partitioned_graph.node_subgraph(pred) {
            sg
        } else {
            let (_edge, pred_pred) = partitioned_graph
                .node_predecessors(pred)
                .next()
                .expect("handoff must have a predecessor");
            partitioned_graph.node_subgraph(pred_pred).unwrap()
        };
        let succ_sg = partitioned_graph.node_subgraph(succ).unwrap();
        // Both sides are in the same subgraph: no inter-subgraph ordering to add.
        if pred_sg == succ_sg {
            continue;
        }
        sg_preds.entry(succ_sg).or_default().push(pred_sg);
        // When `pred` is a handoff, its consumer's subgraph is also ordered after
        // `succ_sg`. NOTE(review): presumably so the referencing subgraph reads
        // the handoff before the consumer drains it; confirm.
        if matches!(partitioned_graph.node(pred), GraphNode::Handoff { .. }) {
            assert!(
                partitioned_graph.node_degree_out(pred) <= 1,
                "handoff should have at most one successor"
            );
            if let Some((_edge, consumer)) = partitioned_graph.node_successors(pred).next() {
                let consumer_sg = partitioned_graph.node_subgraph(consumer).unwrap();
                if consumer_sg != succ_sg {
                    sg_preds.entry(consumer_sg).or_default().push(succ_sg);
                }
            }
        }
    }
    // Any cycle left in `sg_preds` is a cycle within a single tick.
    if let Err(cycle) = graph_algorithms::topo_sort(partitioned_graph.subgraph_ids(), |v| {
        sg_preds.get(&v).into_iter().flatten().copied()
    }) {
        let span = cycle
            .first()
            .and_then(|&sg_id| partitioned_graph.subgraph(sg_id).first().copied())
            .map(|n| partitioned_graph.node(n).span())
            .unwrap_or_else(Span::call_site);
        return Err(Diagnostic::spanned(
            span,
            Level::Error,
            format!("Cyclical dataflow within a tick is not supported. Use `defer_tick()` or `defer_tick_lazy()` to break the cycle across ticks.
Cycle: {:?}", cycle),
        ));
    }
    // Ordering is valid; now record each tick edge's delay on its source handoff.
    for (edge_id, delay_type) in tick_edges {
        let (hoff, _dst) = partitioned_graph.edge(edge_id);
        assert!(matches!(
            partitioned_graph.node(hoff),
            GraphNode::Handoff {
                kind: HandoffKind::Vec,
                ..
            }
        ));
        partitioned_graph.set_handoff_delay_type(hoff, delay_type);
    }
    Ok(())
}
/// Partitions a flat DFIR graph into subgraphs connected by handoffs.
///
/// The pass runs in three steps:
/// 1. Barrier crossers are found on the flat graph.
/// 2. Subgraphs are formed, with handoffs inserted between them.
/// 3. The subgraphs are checked for a valid within-tick ordering.
///
/// # Errors
/// Returns the [`Diagnostic`] produced by `order_subgraphs` when the
/// subgraphs form a cycle within a tick.
pub fn partition_graph(flat_graph: DfirGraph) -> Result<DfirGraph, Diagnostic> {
    let mut barrier_crossers = find_barrier_crossers(&flat_graph);
    let mut graph = flat_graph;
    make_subgraphs(&mut graph, &mut barrier_crossers);
    order_subgraphs(&mut graph, &barrier_crossers).map(|()| graph)
}