use onnx_ir::{Argument, Node};
use std::collections::HashMap;
use std::collections::HashSet;
/// Graphs with fewer nodes than this are never partitioned; callers fall back
/// to flat codegen (checked in both `try_partition` and `find_partition_points`).
pub(crate) const MIN_GRAPH_SIZE: usize = 200;
/// Minimum number of nodes in any chunk; cut candidates closer than this to
/// either end of the graph (or to the previous cut) are rejected.
const MIN_CHUNK_SIZE: usize = 64;
/// Maximum search-window span: the next cut is looked for within this many
/// nodes after the previous one.
const MAX_CHUNK_SIZE: usize = 256;
/// A cut position is only a candidate when at most this many values are live
/// across it (see `compute_cut_widths`).
const MAX_CUT_WIDTH: usize = 64;
/// Result of splitting a graph's node list into consecutive chunks.
///
/// Produced by `try_partition`; all three vectors have one entry per chunk,
/// at matching indices.
#[derive(Debug, Clone)]
pub(crate) struct Partition {
    /// Consecutive half-open node-index ranges covering `0..nodes.len()`.
    pub chunks: Vec<std::ops::Range<usize>>,
    /// Per chunk: values it reads that are produced outside it (by an earlier
    /// chunk or supplied as a graph input).
    pub chunk_inputs: Vec<Vec<Argument>>,
    /// Per chunk: values it produces that are consumed by another chunk or
    /// required as a graph output.
    pub chunk_outputs: Vec<Vec<Argument>>,
}
fn compute_cut_widths(nodes: &[Node], graph_output_names: &[String]) -> Vec<usize> {
let n = nodes.len();
let mut value_spans: HashMap<String, (usize, usize)> = HashMap::new();
let graph_out_set: std::collections::HashSet<&str> =
graph_output_names.iter().map(|s| s.as_str()).collect();
for (i, node) in nodes.iter().enumerate() {
for arg in node.outputs() {
if !arg.name.is_empty() {
value_spans
.entry(arg.name.clone())
.or_insert((i + 1, i + 1));
}
}
for arg in node.inputs() {
if !arg.name.is_empty() && (arg.is_dynamic() || arg.is_constant()) {
value_spans.entry(arg.name.clone()).and_modify(|(_, last)| {
if i > *last {
*last = i;
}
});
value_spans.entry(arg.name.clone()).or_insert((0, i));
}
}
}
for name in &graph_out_set {
if let Some((_, last)) = value_spans.get_mut(*name) {
*last = n;
}
}
let mut delta = vec![0i64; n + 2]; for &(producer, last_consumer) in value_spans.values() {
let start = producer + 1;
let end = last_consumer + 1;
if start <= n && start < end {
delta[start] += 1;
if end <= n {
delta[end] -= 1;
}
}
}
let mut widths = vec![0usize; n + 1];
let mut running: i64 = 0;
for p in 0..=n {
running += delta[p];
debug_assert!(
running >= 0,
"cut width went negative at position {p}, running = {running}"
);
widths[p] = running.max(0) as usize;
}
widths
}
/// Greedily selects cut positions for a graph of `node_count` nodes.
///
/// A position qualifies as a candidate when it leaves at least
/// `MIN_CHUNK_SIZE` nodes on each side and no more than `MAX_CUT_WIDTH`
/// values are live across it. Starting from the last chosen cut, the
/// narrowest candidate inside the next `[MIN_CHUNK_SIZE, MAX_CHUNK_SIZE]`
/// window is picked; windows without candidates are skipped over. Returns an
/// empty vector for small graphs or when no candidate exists at all.
fn find_partition_points(cut_widths: &[usize], node_count: usize) -> Vec<usize> {
    if node_count < MIN_GRAPH_SIZE {
        return vec![];
    }
    // All acceptable (position, width) pairs, in ascending position order.
    let candidates: Vec<(usize, usize)> = cut_widths
        .iter()
        .enumerate()
        .take(node_count.saturating_sub(MIN_CHUNK_SIZE) + 1)
        .skip(MIN_CHUNK_SIZE)
        .filter(|&(_, &w)| w <= MAX_CUT_WIDTH)
        .map(|(p, &w)| (p, w))
        .collect();
    if candidates.is_empty() {
        log::warn!(
            "Graph has {node_count} nodes but no partition points with \
cut width <= {MAX_CUT_WIDTH} were found; falling back to flat codegen"
        );
        return vec![];
    }
    let mut points = Vec::new();
    let mut cursor = 0usize;
    loop {
        let window_start = cursor + MIN_CHUNK_SIZE;
        let window_end = (cursor + MAX_CHUNK_SIZE).min(node_count);
        // Stop once the remaining tail is too short to host another cut.
        if window_start >= node_count.saturating_sub(MIN_CHUNK_SIZE) {
            break;
        }
        // Narrowest candidate in the window; ties go to the leftmost position.
        match candidates
            .iter()
            .filter(|&&(p, _)| (window_start..=window_end).contains(&p))
            .min_by_key(|&&(_, w)| w)
        {
            Some(&(pos, _)) => {
                points.push(pos);
                cursor = pos;
            }
            None => {
                log::debug!(
                    "No acceptable partition point in nodes [{window_start}..{window_end}], \
skipping to next window"
                );
                cursor = window_end;
            }
        }
    }
    points
}
fn compute_chunk_interfaces(
nodes: &[Node],
chunks: &[std::ops::Range<usize>],
graph_input_args: &[Argument],
graph_output_args: &[Argument],
) -> (Vec<Vec<Argument>>, Vec<Vec<Argument>>) {
let num_chunks = chunks.len();
let mut chunk_inputs = vec![Vec::new(); num_chunks];
let mut chunk_outputs = vec![Vec::new(); num_chunks];
let mut producers: HashMap<String, (usize, Argument)> = HashMap::new();
for arg in graph_input_args {
producers.insert(arg.name.clone(), (usize::MAX, arg.clone()));
}
for (chunk_idx, range) in chunks.iter().enumerate() {
for node_idx in range.clone() {
for arg in nodes[node_idx].outputs() {
if !arg.name.is_empty() {
producers.insert(arg.name.clone(), (chunk_idx, arg.clone()));
}
}
}
}
let mut chunk_input_sets: Vec<std::collections::HashSet<String>> =
vec![std::collections::HashSet::new(); num_chunks];
let mut chunk_output_sets: Vec<std::collections::HashSet<String>> =
vec![std::collections::HashSet::new(); num_chunks];
for (chunk_idx, range) in chunks.iter().enumerate() {
for node_idx in range.clone() {
for arg in nodes[node_idx].inputs() {
if arg.name.is_empty() {
continue;
}
if !arg.is_dynamic() && !arg.is_constant() {
continue;
}
if let Some(&(producer_chunk, ref producer_arg)) = producers.get(&arg.name)
&& producer_chunk != chunk_idx
{
if chunk_input_sets[chunk_idx].insert(arg.name.clone()) {
chunk_inputs[chunk_idx].push(producer_arg.clone());
}
if producer_chunk != usize::MAX
&& chunk_output_sets[producer_chunk].insert(arg.name.clone())
{
chunk_outputs[producer_chunk].push(producer_arg.clone());
}
}
}
}
}
for arg in graph_output_args {
if let Some(&(producer_chunk, ref producer_arg)) = producers.get(&arg.name)
&& producer_chunk != usize::MAX
&& chunk_output_sets[producer_chunk].insert(arg.name.clone())
{
chunk_outputs[producer_chunk].push(producer_arg.clone());
}
}
(chunk_inputs, chunk_outputs)
}
/// Reorders `nodes` so every `Constant` node sits immediately before its first
/// consumer, preserving the relative order of all other nodes.
///
/// ONNX exporters often emit all constants at the top of the graph; moving
/// each one next to its first consumer keeps the constant and its use in the
/// same partition chunk, so the constant does not become cross-chunk
/// interface state. Constants with no consumer, or already in place, stay put.
///
/// NOTE(review): this assumes `Constant` nodes consume no constant outputs
/// themselves (true for ONNX Constant ops) — a relocated constant whose first
/// consumer is itself relocated would be silently dropped by the `relocated`
/// skip below; `debug_assert_eq!` guards this in debug builds.
pub(crate) fn reorder_constants_to_consumers(nodes: &mut Vec<Node>) {
    let n = nodes.len();
    if n == 0 {
        return;
    }
    // Mark constant nodes and map each of their output names to the node index.
    let mut is_constant = vec![false; n];
    let mut const_output_to_idx: HashMap<String, usize> = HashMap::new();
    for (i, node) in nodes.iter().enumerate() {
        if matches!(node, Node::Constant(_)) {
            is_constant[i] = true;
            for arg in node.outputs() {
                if !arg.name.is_empty() {
                    const_output_to_idx.insert(arg.name.clone(), i);
                }
            }
        }
    }
    if const_output_to_idx.is_empty() {
        return;
    }
    // First consumer (smallest node index) of each constant's output.
    let mut const_first_consumer: HashMap<usize, usize> = HashMap::new();
    for (i, node) in nodes.iter().enumerate() {
        for arg in node.inputs() {
            if let Some(&const_idx) = const_output_to_idx.get(&arg.name) {
                const_first_consumer.entry(const_idx).or_insert(i);
            }
        }
    }
    // Group constants by the consumer they should be spliced in front of.
    // Constants already immediately before their first consumer, or with no
    // consumer at all, are left in their original position. (The previous
    // version also accumulated the latter into a dead `orphan_constants`
    // vector that was never read — removed.)
    let mut consumer_to_constants: HashMap<usize, Vec<usize>> = HashMap::new();
    for (i, &is_const) in is_constant.iter().enumerate() {
        if !is_const {
            continue;
        }
        if let Some(&consumer) = const_first_consumer.get(&i) {
            if consumer != i + 1 {
                consumer_to_constants.entry(consumer).or_default().push(i);
            }
        }
    }
    if consumer_to_constants.is_empty() {
        return;
    }
    // Emit nodes in original order, splicing each relocated constant in just
    // before its first consumer.
    let relocated: HashSet<usize> = consumer_to_constants.values().flatten().copied().collect();
    let mut new_order: Vec<usize> = Vec::with_capacity(n);
    for i in 0..n {
        if relocated.contains(&i) {
            continue;
        }
        if let Some(consts) = consumer_to_constants.get(&i) {
            new_order.extend(consts);
        }
        new_order.push(i);
    }
    debug_assert_eq!(new_order.len(), n);
    // Move nodes into the new order without cloning.
    let mut slots: Vec<Option<Node>> = nodes.drain(..).map(Some).collect();
    *nodes = new_order
        .into_iter()
        .map(|i| slots[i].take().expect("node used twice"))
        .collect();
}
/// Attempts to split `nodes` into consecutive chunks with narrow cross-chunk
/// interfaces.
///
/// Returns `None` when the graph has fewer than `MIN_GRAPH_SIZE` nodes or no
/// acceptable cut positions exist; callers should then fall back to flat
/// codegen.
pub(crate) fn try_partition(
    nodes: &[Node],
    graph_input_args: &[Argument],
    graph_output_args: &[Argument],
) -> Option<Partition> {
    let n = nodes.len();
    if n < MIN_GRAPH_SIZE {
        return None;
    }
    let graph_output_names: Vec<String> =
        graph_output_args.iter().map(|a| a.name.clone()).collect();
    let cut_widths = compute_cut_widths(nodes, &graph_output_names);
    let points = find_partition_points(&cut_widths, n);
    if points.is_empty() {
        return None;
    }
    // Turn the cut positions into half-open ranges covering 0..n: adjacent
    // pairs of [0, points..., n] become one chunk each.
    let boundaries: Vec<usize> = std::iter::once(0)
        .chain(points.iter().copied())
        .chain(std::iter::once(n))
        .collect();
    let chunks: Vec<std::ops::Range<usize>> =
        boundaries.windows(2).map(|pair| pair[0]..pair[1]).collect();
    let (chunk_inputs, chunk_outputs) =
        compute_chunk_interfaces(nodes, &chunks, graph_input_args, graph_output_args);
    Some(Partition {
        chunks,
        chunk_inputs,
        chunk_outputs,
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn cut_widths_empty_graph() {
        // An empty graph still has the single cut position 0, with nothing live.
        assert_eq!(compute_cut_widths(&[], &[]), vec![0]);
    }

    #[test]
    fn partition_returns_none_for_small_graph() {
        assert!(try_partition(&[], &[], &[]).is_none());
    }

    #[test]
    fn find_partition_points_returns_empty_for_small() {
        // Below MIN_GRAPH_SIZE no points are produced, regardless of widths.
        let widths = vec![0usize; 50];
        assert!(find_partition_points(&widths, 49).is_empty());
    }

    #[test]
    fn find_partition_points_finds_narrow_cut() {
        // A single narrow position in an otherwise uniform profile is chosen.
        let widths: Vec<usize> = (0..=300).map(|p| if p == 150 { 1 } else { 5 }).collect();
        assert!(find_partition_points(&widths, 300).contains(&150));
    }

    #[test]
    fn find_partition_points_multiple_cuts() {
        // Two narrow positions, each in its own window, are both selected.
        let widths: Vec<usize> = (0..=600)
            .map(|p| if p == 200 || p == 400 { 1 } else { 5 })
            .collect();
        let points = find_partition_points(&widths, 600);
        assert!(points.contains(&200));
        assert!(points.contains(&400));
    }
}