use std::collections::{HashMap, HashSet};
use crate::error::CompileError;
use bb_ir::proto::onnx::{FunctionProto, ModelProto, NodeProto};
/// Node domain marking a call to another function of the model; for such a
/// node, `op_type` is compared against function names.
const MODULE_CALL_DOMAIN: &str = "ai.bytesandbrains.module";
/// Node domain whose presence in a function body puts that function (and,
/// transitively, every function calling it) into the wire closure.
const WIRE_DOMAIN: &str = "ai.bytesandbrains.wire";
/// Node domain accepted as a plain ONNX operator by the pure-ONNX classification.
const ONNX_DOMAIN: &str = "ai.onnx";
/// Repeatedly inlines every function classified as inlinable into its callers
/// and removes the inlined definitions from `model.functions`.
///
/// The first entry of `model.functions` is treated as the root and is never
/// classified as inlinable. Each pass:
///
/// 1. classifies the inlinable functions (stops when there are none),
/// 2. processes them callees-first, replacing every call node
///    (`domain == MODULE_CALL_DOMAIN`, `op_type == <function name>`) in every
///    *other* function with a renamed copy of the callee body,
/// 3. merges the callee's renamed `value_info` entries into the caller,
///    skipping names the caller already has,
/// 4. drops all functions that were classified as inlinable in this pass.
///
/// # Returns
///
/// The total number of call nodes that were replaced.
///
/// # Errors
///
/// The current implementation has no failure path and always returns `Ok`.
///
/// NOTE(review): this assumes the call graph is acyclic. A function is never
/// inlined into itself, yet it is still removed at the end of the pass, so a
/// self-recursive inlinable function would leave a dangling call — confirm
/// that recursion is rejected before this pass runs.
pub fn inline_for_partition(model: &mut ModelProto) -> Result<usize, CompileError> {
    let root_name = model.functions.first().map(|f| f.name.clone());
    let mut total_inlines: usize = 0;
    // Shared across all call sites so every inlined body gets a distinct suffix.
    let mut next_unique: u64 = 0;

    loop {
        let inlinable = classify_inlinable(model, root_name.as_deref());
        if inlinable.is_empty() {
            break;
        }

        for name in reverse_topo_order(model, &inlinable) {
            // Snapshot the callee: callers are mutated below while we read it.
            let Some(body) = model.functions.iter().find(|f| f.name == name).cloned() else {
                continue;
            };

            for caller in model.functions.iter_mut() {
                if caller.name == name {
                    continue;
                }
                // Leave functions without a matching call site untouched
                // instead of rebuilding (and deep-cloning) their node list.
                let has_call_site = caller
                    .node
                    .iter()
                    .any(|n| n.domain == MODULE_CALL_DOMAIN && n.op_type == name);
                if !has_call_site {
                    continue;
                }

                // Move the nodes out so unaffected ones are re-pushed, not cloned.
                let original_nodes = std::mem::take(&mut caller.node);
                let mut rewritten: Vec<NodeProto> = Vec::with_capacity(original_nodes.len());
                let mut inlined_value_info: Vec<bb_ir::proto::onnx::ValueInfoProto> = Vec::new();

                for node in original_nodes {
                    if node.domain == MODULE_CALL_DOMAIN && node.op_type == name {
                        let (nodes, value_info) = inline_one_call(&body, &node, &mut next_unique);
                        rewritten.extend(nodes);
                        inlined_value_info.extend(value_info);
                        total_inlines += 1;
                    } else {
                        rewritten.push(node);
                    }
                }
                caller.node = rewritten;

                // Existing caller entries win over the ones brought in by inlining.
                for vi in inlined_value_info {
                    if !caller.value_info.iter().any(|v| v.name == vi.name) {
                        caller.value_info.push(vi);
                    }
                }
            }
        }

        // Removing the whole set shrinks the function list, so the loop terminates.
        model.functions.retain(|f| !inlinable.contains(&f.name));
    }

    Ok(total_inlines)
}
/// Returns the names of all functions that should be inlined in the current pass.
///
/// A function qualifies when it is not the root (`root_name`) and at least one
/// of the following holds:
/// - it belongs to the wire closure,
/// - it belongs to the pure-ONNX closure,
/// - it is called from exactly one call site in the whole model.
fn classify_inlinable(model: &ModelProto, root_name: Option<&str>) -> HashSet<String> {
    let wire = wire_closure(model);
    let pure = pure_onnx_closure(model);
    let calls = count_call_sites(model);

    model
        .functions
        .iter()
        // The root is the entry point and must survive inlining.
        .filter(|func| root_name != Some(func.name.as_str()))
        .filter(|func| {
            wire.contains(&func.name)
                || pure.contains(&func.name)
                || calls.get(&func.name) == Some(&1)
        })
        .map(|func| func.name.clone())
        .collect()
}
/// Counts, per callee name, how many call nodes exist across all functions.
///
/// A call node is any node whose domain equals `MODULE_CALL_DOMAIN`; its
/// `op_type` is used as the callee name. Names that are never called do not
/// appear in the returned map.
fn count_call_sites(model: &ModelProto) -> HashMap<String, usize> {
    model
        .functions
        .iter()
        .flat_map(|func| func.node.iter())
        .filter(|node| node.domain == MODULE_CALL_DOMAIN)
        .fold(HashMap::<String, usize>::new(), |mut tally, node| {
            *tally.entry(node.op_type.clone()).or_insert(0) += 1;
            tally
        })
}
/// Computes the set of function names that contain a wire node directly or
/// reach one through calls.
///
/// Seeded with every function that has a node in `WIRE_DOMAIN`, then grown to
/// a fixed point by adding any function that calls (via a `MODULE_CALL_DOMAIN`
/// node) a function already in the set.
fn wire_closure(model: &ModelProto) -> HashSet<String> {
    let mut reached: HashSet<String> = HashSet::new();

    // Seed: functions that contain a wire node themselves.
    for func in &model.functions {
        if func.node.iter().any(|n| n.domain == WIRE_DOMAIN) {
            reached.insert(func.name.clone());
        }
    }

    // Propagate to callers until a full sweep adds nothing new.
    let mut grew = true;
    while grew {
        grew = false;
        for func in &model.functions {
            if reached.contains(&func.name) {
                continue;
            }
            let calls_reached = func
                .node
                .iter()
                .any(|n| n.domain == MODULE_CALL_DOMAIN && reached.contains(&n.op_type));
            if calls_reached {
                reached.insert(func.name.clone());
                grew = true;
            }
        }
    }

    reached
}
/// Computes the set of function names whose bodies consist solely of ONNX
/// operators, directly or through calls to other such functions.
///
/// A function joins the set when its body is non-empty and every node is
/// either in `ONNX_DOMAIN` or a `MODULE_CALL_DOMAIN` call to a function that
/// is already in the set. The sweep repeats until nothing new is added.
///
/// NOTE(review): a node with an empty domain string is not accepted here,
/// only the literal `ONNX_DOMAIN` value — confirm that domains are normalised
/// before this pass if empty-domain nodes should count as ONNX operators.
fn pure_onnx_closure(model: &ModelProto) -> HashSet<String> {
    let mut pure: HashSet<String> = HashSet::new();

    let mut grew = true;
    while grew {
        grew = false;
        for func in &model.functions {
            // Already classified, or an empty body (which never qualifies).
            if pure.contains(&func.name) || func.node.is_empty() {
                continue;
            }
            // One offending node is enough to keep the function out for now.
            let disqualified = func.node.iter().any(|n| {
                let is_call = n.domain == MODULE_CALL_DOMAIN;
                (is_call && !pure.contains(&n.op_type)) || (!is_call && n.domain != ONNX_DOMAIN)
            });
            if disqualified {
                continue;
            }
            pure.insert(func.name.clone());
            grew = true;
        }
    }

    pure
}
/// Orders the inlinable functions so that every callee appears before any
/// inlinable function that calls it (post-order of a depth-first walk over
/// `MODULE_CALL_DOMAIN` edges, restricted to `inlinable`).
///
/// The walk is rooted at the inlinable functions in their declaration order
/// in `model.functions`, and callees are followed in node order, so the
/// result is deterministic for a given model. Rooting the walk at hash-map
/// keys instead would make the order — and with it the unique suffixes handed
/// out during inlining — vary from run to run.
///
/// Each name appears at most once in the returned list.
fn reverse_topo_order(model: &ModelProto, inlinable: &HashSet<String>) -> Vec<String> {
    // Name -> position in `model.functions`, for inlinable functions only.
    let inlinable_idx: HashMap<String, usize> = model
        .functions
        .iter()
        .enumerate()
        .filter(|(_, f)| inlinable.contains(&f.name))
        .map(|(i, f)| (f.name.clone(), i))
        .collect();

    let mut visited: HashSet<String> = HashSet::new();
    let mut order: Vec<String> = Vec::with_capacity(inlinable_idx.len());

    /// Depth-first visit that appends `name` after all of its inlinable callees.
    fn visit(
        name: &str,
        model: &ModelProto,
        inlinable_idx: &HashMap<String, usize>,
        visited: &mut HashSet<String>,
        order: &mut Vec<String>,
    ) {
        // Marking before descending also stops the walk on a cyclic call chain.
        if !visited.insert(name.to_string()) {
            return;
        }
        let Some(&idx) = inlinable_idx.get(name) else {
            return;
        };
        let f = &model.functions[idx];
        for node in &f.node {
            if node.domain == MODULE_CALL_DOMAIN && inlinable_idx.contains_key(&node.op_type) {
                visit(&node.op_type, model, inlinable_idx, visited, order);
            }
        }
        order.push(name.to_string());
    }

    // Declaration order gives stable DFS roots; any root order yields a valid
    // callee-first ordering.
    for f in model.functions.iter().filter(|f| inlinable.contains(&f.name)) {
        visit(&f.name, model, &inlinable_idx, &mut visited, &mut order);
    }

    order
}
/// Expands a single call node into a renamed copy of the callee's body.
///
/// Value names are rewritten as follows:
/// - formal inputs of `body` map positionally to the inputs of `call`,
/// - outputs of `body` map positionally to the outputs of `call`
///   (an output mapping overrides an input mapping for the same name),
/// - empty names stay empty,
/// - every other name becomes `"{name}#inl{unique}"`, where `unique` is the
///   value of `*next_unique` on entry.
///
/// `*next_unique` is advanced by one (saturating) on every invocation. Only
/// the `input` and `output` lists of each cloned node are rewritten; all other
/// node fields are copied as they are.
///
/// # Returns
///
/// The renamed body nodes, plus the `value_info` entries of `body` whose
/// names took part in the renaming, carrying their new names.
fn inline_one_call(
    body: &FunctionProto,
    call: &NodeProto,
    next_unique: &mut u64,
) -> (Vec<NodeProto>, Vec<bb_ir::proto::onnx::ValueInfoProto>) {
    let unique = *next_unique;
    *next_unique = next_unique.saturating_add(1);

    /// Looks up the replacement for `name`, minting a suffixed one on first use.
    fn resolve(table: &mut HashMap<String, String>, unique: u64, name: &str) -> String {
        if name.is_empty() {
            return String::new();
        }
        table
            .entry(name.to_string())
            .or_insert_with(|| format!("{name}#inl{unique}"))
            .clone()
    }

    // Positional binding; `zip` stops at the shorter side, so formals without
    // a matching actual are left for `resolve` to handle.
    let mut rename: HashMap<String, String> = HashMap::new();
    rename.extend(body.input.iter().cloned().zip(call.input.iter().cloned()));
    rename.extend(body.output.iter().cloned().zip(call.output.iter().cloned()));

    let out: Vec<NodeProto> = body
        .node
        .iter()
        .map(|template| {
            let mut copy = template.clone();
            for slot in copy.input.iter_mut().chain(copy.output.iter_mut()) {
                *slot = resolve(&mut rename, unique, slot);
            }
            copy
        })
        .collect();

    // Carry over type information only for names that were actually renamed.
    let mut value_info: Vec<bb_ir::proto::onnx::ValueInfoProto> = Vec::new();
    for info in &body.value_info {
        if let Some(new_name) = rename.get(&info.name) {
            let mut renamed = info.clone();
            renamed.name = new_name.clone();
            value_info.push(renamed);
        }
    }

    (out, value_info)
}