use std::collections::BTreeSet;
use itertools::Itertools;
use slotmap::{SecondaryMap, SparseSecondaryMap};
use super::meta_graph::DfirGraph;
use super::ops::{DelayType, FloType};
use super::{Color, GraphEdgeId, GraphNode, GraphNodeId, HandoffKind};
use crate::diagnostic::{Diagnostic, Level};
use crate::graph::graph_algorithms::SubgraphMerge;
/// Finds every edge whose destination operator declares a delay on the receiving input port.
///
/// Returns a pair:
/// * a map from edge to [`DelayType`] holding only the `Tick` / `TickLazy` edges, and
/// * the `(src, dst)` endpoints of *every* delayed edge, whatever its delay type.
///
/// Edges into nodes that belong to a loop (`node_loop(dst)` is `Some`) and edges into nodes
/// without an operator instance are ignored.
fn find_edge_barriers(
    partitioned_graph: &DfirGraph,
) -> (
    SecondaryMap<GraphEdgeId, DelayType>,
    Vec<(GraphNodeId, GraphNodeId)>,
) {
    // Lazily yield `(edge, src, dst, delay)` for each edge whose destination port is delayed.
    let delayed_edges = partitioned_graph
        .edges()
        .filter(|&(_, (_, dst))| partitioned_graph.node_loop(dst).is_none())
        .filter_map(|(edge_id, (src, dst))| {
            let op_inst = partitioned_graph.node_op_inst(dst)?;
            let (_, dst_port) = partitioned_graph.edge_ports(edge_id);
            let delay_type = (op_inst.op_constraints.input_delaytype_fn)(dst_port)?;
            Some((edge_id, src, dst, delay_type))
        });

    let mut tick_edges = SecondaryMap::new();
    let mut barrier_pairs = Vec::new();
    for (edge_id, src, dst, delay_type) in delayed_edges {
        barrier_pairs.push((src, dst));
        // Only tick-crossing delays are remembered per edge.
        if let DelayType::Tick | DelayType::TickLazy = delay_type {
            tick_edges.insert(edge_id, delay_type);
        }
    }
    (tick_edges, barrier_pairs)
}
/// Derives `(before, after)` ordering pairs from the graph's handoff reference groups.
///
/// For each entry returned by `node_handoff_reference_groups()`, consecutive groups (in the
/// iteration order of `groups.values()`) are paired up. Every node of the earlier group is then
/// ordered before every node of the later group.
///
/// # Panics
/// Panics if the same node appears in two consecutive groups. That indicates conflicting or
/// cyclical handoff references.
fn find_access_group_ordering(partitioned_graph: &DfirGraph) -> Vec<(GraphNodeId, GraphNodeId)> {
    let mut pairs = Vec::new();
    let refs_by_target = partitioned_graph.node_handoff_reference_groups();
    for (_handoff, groups) in refs_by_target {
        // Pair each group with the one immediately following it.
        for (group_a, group_b) in groups.values().tuple_windows() {
            for &(node_a, _, _) in group_a {
                for &(node_b, _, _) in group_b {
                    // A node cannot be ordered before itself.
                    assert_ne!(
                        node_a, node_b,
                        "encountered conflicting or cyclical handoff references\n{:?}\n{:?}",
                        group_a, group_b,
                    );
                    pairs.push((node_a, node_b));
                }
            }
        }
    }
    pairs
}
/// Computes which nodes can be fused into the same subgraph.
///
/// Returns the resulting [`SubgraphMerge`] partition together with the set of edges that were
/// *not* fused and therefore still need a handoff. Edges adjacent to an existing
/// `GraphNode::Handoff` are removed from that set.
///
/// # Parameters
/// * `tick_edges` - tick-delayed edges; these are left out of the predecessor (ordering)
///   relation.
/// * `edge_barrier_pairs` - `(src, dst)` pairs that are passed to `SubgraphMerge` as "enemies".
/// * `access_group_pairs` - `(before, after)` pairs used both as ordering constraints and as
///   "enemies".
///
/// # Errors
/// Returns a [`Diagnostic`] when `SubgraphMerge::new` reports a cycle, i.e. cyclical dataflow
/// within a single tick.
fn find_subgraph_unionfind(
    partitioned_graph: &DfirGraph,
    tick_edges: &SecondaryMap<GraphEdgeId, DelayType>,
    edge_barrier_pairs: &[(GraphNodeId, GraphNodeId)],
    access_group_pairs: &[(GraphNodeId, GraphNodeId)],
) -> Result<(SubgraphMerge<GraphNodeId>, BTreeSet<GraphEdgeId>), Diagnostic> {
    // Seed colors for nodes that have one. Uncolored nodes get a color inferred later by
    // `can_connect_colorize` as edges are fused.
    let mut node_color = partitioned_graph
        .node_ids()
        .filter_map(|node_id| {
            let op_color = partitioned_graph.node_color(node_id)?;
            Some((node_id, op_color))
        })
        .collect::<SparseSecondaryMap<_, _>>();
    // Predecessor relation (node -> nodes that must come before it).
    // Tick-delayed edges are excluded; this is what lets `defer_tick()` break a cycle.
    let mut all_preds: SecondaryMap<GraphNodeId, Vec<GraphNodeId>> = SecondaryMap::new();
    for (edge_id, (src, dst)) in partitioned_graph.edges() {
        if !tick_edges.contains_key(edge_id) {
            all_preds.entry(dst).unwrap().or_default().push(src);
        }
    }
    // Handoff references: a node referencing `src` must come after `src`.
    for node_id in partitioned_graph.node_ids() {
        for handoff_ref in partitioned_graph.node_handoff_references(node_id).iter() {
            if let Some(src) = handoff_ref.node_id {
                all_preds.entry(node_id).unwrap().or_default().push(src);
                // If the referenced node is a handoff, its consumers must in turn come after
                // the referencing node.
                if let GraphNode::Handoff { .. } = partitioned_graph.node(src) {
                    for (_edge, consumer) in partitioned_graph.node_successors(src) {
                        all_preds
                            .entry(consumer)
                            .unwrap()
                            .or_default()
                            .push(node_id);
                    }
                }
            }
        }
    }
    // Ordering between consecutive handoff reference groups.
    for &(src, dst) in access_group_pairs {
        all_preds.entry(dst).unwrap().or_default().push(src);
    }
    // Pairs handed to `SubgraphMerge` as enemies: delayed edges, access-group orderings, and
    // (referenced node, referencing node) pairs.
    // NOTE(review): presumably enemies are kept in different subgraphs - confirm in
    // `SubgraphMerge`.
    let enemies = edge_barrier_pairs
        .iter()
        .copied()
        .chain(access_group_pairs.iter().copied())
        .chain(partitioned_graph.node_ids().flat_map(|dst| {
            partitioned_graph
                .node_handoff_references(dst)
                .iter()
                .filter_map(|r| r.node_id)
                .map(move |src| (src, dst))
        }));
    // Any `Err` here is treated as a cycle: it is turned into the diagnostic below.
    let mut subgraph_unionfind = SubgraphMerge::<GraphNodeId>::new(
        partitioned_graph.node_ids(),
        |node_id| all_preds.get(node_id).into_iter().flatten().copied(),
        enemies,
    )
    .map_err(|cycle| {
        // Point the diagnostic at the first node of the cycle, if any.
        let span = cycle
            .first()
            .map(|&node_id| partitioned_graph.node(node_id).span())
            .unwrap_or_else(proc_macro2::Span::call_site);
        let node_cycle = cycle
            .iter()
            .map(|&node_id| partitioned_graph.node(node_id).to_pretty_string())
            .collect::<Vec<_>>();
        Diagnostic::spanned(
            span,
            Level::Error,
            format!(
                "Cyclical dataflow within a tick is not supported. Use `defer_tick()` or `defer_tick_lazy()` to break the cycle across ticks. \
                Cycle: {:?}",
                node_cycle,
            ),
        )
    })?;
    // Start with every edge as a handoff candidate. Edges that get fused, or that already
    // touch a handoff node, are removed.
    let mut handoff_edges: BTreeSet<GraphEdgeId> = partitioned_graph.edge_ids().collect();
    // Greedy fixpoint. Fusing an edge may color a previously uncolored node, which can make
    // an edge rejected earlier in the pass fusable on a later pass.
    let mut progress = true;
    while progress {
        progress = false;
        for (edge_id, (src, dst)) in partitioned_graph.edges().collect::<Vec<_>>() {
            // Edges adjacent to an existing handoff never need a new one.
            if matches!(partitioned_graph.node(src), GraphNode::Handoff { .. })
                || matches!(partitioned_graph.node(dst), GraphNode::Handoff { .. })
            {
                handoff_edges.remove(&edge_id);
                continue;
            }
            // Already in the same subgraph. This edge stays in `handoff_edges` unless it is
            // the edge that fused them.
            if subgraph_unionfind.same_set(src, dst) {
                continue;
            }
            // Never fuse across different loop contexts.
            if partitioned_graph.node_loop(src) != partitioned_graph.node_loop(dst) {
                continue;
            }
            // Never fuse into a `NextIteration` operator.
            if partitioned_graph.node_op_inst(dst).is_some_and(|op_inst| {
                Some(FloType::NextIteration) == op_inst.op_constraints.flo_type
            }) {
                continue;
            }
            if can_connect_colorize(&mut node_color, src, dst) {
                // `try_merge` may still refuse (presumably because of enemies or ordering
                // constraints - confirm in `SubgraphMerge`).
                let ok = subgraph_unionfind.try_merge(src, dst);
                if ok {
                    assert!(handoff_edges.remove(&edge_id));
                    progress = true;
                }
            }
        }
    }
    Ok((subgraph_unionfind, handoff_edges))
}
/// Fuses nodes into subgraphs and materializes handoffs between them.
///
/// Handoff insertion:
/// * Every unfused edge that does not already touch a handoff node gets a new
///   `HandoffKind::Vec` handoff spliced into it.
/// * If that edge was tick-delayed, its `tick_edges` entry is moved to the new handoff's
///   outgoing edge.
///
/// Subgraph registration:
/// * Subgraphs are registered in the order yielded by `SubgraphMerge::subgraphs`. Empty
///   groups and groups containing a pre-existing handoff node are skipped.
/// * That order is stored with `set_subgraph_toposort`.
///
/// NOTE(review): the `subgraphs()` order is assumed to be topological - confirm in
/// `SubgraphMerge`.
///
/// # Errors
/// Propagates the cycle [`Diagnostic`] produced by `find_subgraph_unionfind`.
///
/// # Panics
/// Panics if `insert_subgraph` returns an error.
fn make_subgraphs(
    partitioned_graph: &mut DfirGraph,
    tick_edges: &mut SecondaryMap<GraphEdgeId, DelayType>,
    edge_barrier_pairs: &[(GraphNodeId, GraphNodeId)],
    access_group_pairs: &[(GraphNodeId, GraphNodeId)],
) -> Result<(), Diagnostic> {
    let (subgraph_merge, handoff_edges) = find_subgraph_unionfind(
        partitioned_graph,
        tick_edges,
        edge_barrier_pairs,
        access_group_pairs,
    )?;

    // Splice a fresh `Vec` handoff into every remaining cross-subgraph edge.
    for edge_id in handoff_edges {
        let (src_id, dst_id) = partitioned_graph.edge(edge_id);
        let src_node = partitioned_graph.node(src_id);
        let dst_node = partitioned_graph.node(dst_id);
        let touches_handoff = matches!(src_node, GraphNode::Handoff { .. })
            || matches!(dst_node, GraphNode::Handoff { .. });
        if touches_handoff {
            continue;
        }
        let new_handoff = GraphNode::Handoff {
            kind: HandoffKind::Vec,
            src_span: src_node.span(),
            dst_span: dst_node.span(),
        };
        let (_, out_edge_id) = partitioned_graph.insert_intermediate_node(edge_id, new_handoff);
        // Keep the delay annotation on the edge leaving the new handoff, which is where
        // `mark_tick_boundary_handoffs` looks for it.
        if let Some(delay_type) = tick_edges.remove(edge_id) {
            tick_edges.insert(out_edge_id, delay_type);
        }
    }

    // Register each fused group as a subgraph, preserving the merge order.
    let mut subgraph_toposort = Vec::new();
    for nodes in subgraph_merge.subgraphs() {
        let skip = nodes.is_empty()
            || nodes.iter().any(|&node_id| {
                matches!(partitioned_graph.node(node_id), GraphNode::Handoff { .. })
            });
        if skip {
            continue;
        }
        subgraph_toposort.push(partitioned_graph.insert_subgraph(nodes.to_vec()).unwrap());
    }
    partitioned_graph.set_subgraph_toposort(subgraph_toposort);
    Ok(())
}
/// Decides whether edge `src -> dst` may be fused into one subgraph.
///
/// If exactly one endpoint is colored, the other endpoint's color is inferred, written into
/// `node_color`, and `true` is returned:
/// * a colored `dst` of `Pull`/`Comp` makes `src` `Pull`;
/// * a colored `dst` of `Push`/`Hoff` makes `src` `Push`;
/// * a colored `src` of `Pull`/`Hoff` makes `dst` `Pull`;
/// * a colored `src` of `Comp`/`Push` makes `dst` `Push`.
///
/// If both endpoints are colored, fusion is allowed only for:
/// * `Pull -> Pull`, `Pull -> Comp`, `Pull -> Push`;
/// * `Comp -> Push`, `Push -> Push`.
///
/// Any pair involving `Hoff` is rejected. If neither endpoint is colored, nothing is inferred
/// and `false` is returned.
fn can_connect_colorize(
    node_color: &mut SparseSecondaryMap<GraphNodeId, Color>,
    src: GraphNodeId,
    dst: GraphNodeId,
) -> bool {
    match (node_color.get(src), node_color.get(dst)) {
        // Nothing to infer from yet.
        (None, None) => false,

        // Only the downstream side is known: infer the upstream side.
        (None, Some(Color::Pull | Color::Comp)) => {
            node_color.insert(src, Color::Pull);
            true
        }
        (None, Some(Color::Push | Color::Hoff)) => {
            node_color.insert(src, Color::Push);
            true
        }

        // Only the upstream side is known: infer the downstream side.
        (Some(Color::Pull | Color::Hoff), None) => {
            node_color.insert(dst, Color::Pull);
            true
        }
        (Some(Color::Comp | Color::Push), None) => {
            node_color.insert(dst, Color::Push);
            true
        }

        // Both sides known. Handoffs are never fused.
        (Some(Color::Hoff), Some(_)) | (Some(_), Some(Color::Hoff)) => false,
        // A pull source may feed any non-handoff node.
        (Some(Color::Pull), Some(_)) => true,
        // Once pushing (or after a computation), only pushes may follow.
        (Some(Color::Comp | Color::Push), Some(Color::Push)) => true,
        (Some(Color::Comp | Color::Push), Some(Color::Pull | Color::Comp)) => false,
    }
}
/// Tags handoffs whose outgoing edge is tick-delayed with that edge's [`DelayType`].
///
/// For every handoff node with at least one successor, the first outgoing edge returned by
/// `node_successors` is looked up in `tick_edges`. If it is present, the handoff's delay type
/// is set via `set_handoff_delay_type`.
fn mark_tick_boundary_handoffs(
    partitioned_graph: &mut DfirGraph,
    tick_edges: &SecondaryMap<GraphEdgeId, DelayType>,
) {
    // Gather first (immutable borrow), then apply (mutable borrow).
    let mut tick_handoffs = Vec::new();
    for (hoff_id, hoff) in partitioned_graph.nodes() {
        let is_handoff = matches!(hoff, GraphNode::Handoff { .. });
        if !is_handoff || partitioned_graph.node_degree_out(hoff_id) == 0 {
            continue;
        }
        let (succ_edge, _) = partitioned_graph.node_successors(hoff_id).next().unwrap();
        if let Some(&delay_type) = tick_edges.get(succ_edge) {
            tick_handoffs.push((hoff_id, delay_type));
        }
    }

    for (hoff_id, delay_type) in tick_handoffs {
        partitioned_graph.set_handoff_delay_type(hoff_id, delay_type);
    }
}
/// Partitions a flat DFIR graph into subgraphs connected by handoffs.
///
/// Steps:
/// 1. Collect delayed (barrier) edges and tick edges, plus ordering constraints derived from
///    handoff reference groups.
/// 2. Fuse nodes into subgraphs and insert `Vec` handoffs on the remaining cross-subgraph
///    edges.
/// 3. Tag each handoff that feeds a tick-delayed edge with its delay type.
///
/// # Errors
/// Returns a [`Diagnostic`] if the graph contains cyclical dataflow within a single tick.
pub fn partition_graph(flat_graph: DfirGraph) -> Result<DfirGraph, Diagnostic> {
    let mut graph = flat_graph;

    // Read-only analyses over the flat graph.
    let (mut tick_edges, barriers) = find_edge_barriers(&graph);
    let orderings = find_access_group_ordering(&graph);

    // Mutating passes.
    make_subgraphs(&mut graph, &mut tick_edges, &barriers, &orderings)?;
    mark_tick_boundary_handoffs(&mut graph, &tick_edges);

    Ok(graph)
}