bb_compiler/
inline_for_partition.rs1use std::collections::{HashMap, HashSet};
34
35use crate::error::CompileError;
36use bb_ir::proto::onnx::{FunctionProto, ModelProto, NodeProto};
37
/// Domain of nodes that call another function of the model; the node's
/// `op_type` is the callee's function name.
const MODULE_CALL_DOMAIN: &str = "ai.bytesandbrains.module";

/// Domain of wire nodes. A (non-root) function containing one, directly or via
/// the functions it calls, is classified as inlinable.
const WIRE_DOMAIN: &str = "ai.bytesandbrains.wire";

/// Domain name accepted as standard ONNX when deciding whether a function is
/// made purely of ONNX operators.
const ONNX_DOMAIN: &str = "ai.onnx";
41
42pub fn inline_for_partition(model: &mut ModelProto) -> Result<usize, CompileError> {
47 let root_name = model.functions.first().map(|f| f.name.clone());
48 let mut total_inlines: usize = 0;
49 let mut next_unique: u64 = 0;
50
51 loop {
52 let inlinable = classify_inlinable(model, root_name.as_deref());
53 if inlinable.is_empty() {
54 break;
55 }
56 let order = reverse_topo_order(model, &inlinable);
57
58 for name in order {
59 let body = match model.functions.iter().find(|f| f.name == name) {
62 Some(f) => f.clone(),
63 None => continue,
64 };
65
66 for caller in model.functions.iter_mut() {
67 if caller.name == name {
68 continue;
69 }
70 let mut rewritten: Vec<NodeProto> = Vec::with_capacity(caller.node.len());
71 let mut inlined_value_info: Vec<bb_ir::proto::onnx::ValueInfoProto> = Vec::new();
72 for node in caller.node.iter() {
73 if node.domain == MODULE_CALL_DOMAIN && node.op_type == name {
74 let (nodes, value_info) = inline_one_call(&body, node, &mut next_unique);
75 rewritten.extend(nodes);
76 inlined_value_info.extend(value_info);
77 total_inlines += 1;
78 } else {
79 rewritten.push(node.clone());
80 }
81 }
82 caller.node = rewritten;
83 for vi in inlined_value_info {
87 if !caller.value_info.iter().any(|v| v.name == vi.name) {
88 caller.value_info.push(vi);
89 }
90 }
91 }
92 }
93
94 model.functions.retain(|f| !inlinable.contains(&f.name));
97 }
98
99 Ok(total_inlines)
100}
101
102fn classify_inlinable(model: &ModelProto, root_name: Option<&str>) -> HashSet<String> {
106 let wire_touching = wire_closure(model);
107 let pure_onnx = pure_onnx_closure(model);
108 let call_counts = count_call_sites(model);
109
110 let mut result = HashSet::new();
111 for f in &model.functions {
112 if root_name == Some(f.name.as_str()) {
113 continue;
114 }
115 let single_call = call_counts.get(&f.name).copied() == Some(1);
116 if wire_touching.contains(&f.name) || pure_onnx.contains(&f.name) || single_call {
117 result.insert(f.name.clone());
118 }
119 }
120 result
121}
122
123fn count_call_sites(model: &ModelProto) -> HashMap<String, usize> {
126 let mut counts: HashMap<String, usize> = HashMap::new();
127 for f in &model.functions {
128 for node in &f.node {
129 if node.domain == MODULE_CALL_DOMAIN {
130 *counts.entry(node.op_type.clone()).or_insert(0) += 1;
131 }
132 }
133 }
134 counts
135}
136
137fn wire_closure(model: &ModelProto) -> HashSet<String> {
142 let mut closure: HashSet<String> = model
143 .functions
144 .iter()
145 .filter(|f| f.node.iter().any(|n| n.domain == WIRE_DOMAIN))
146 .map(|f| f.name.clone())
147 .collect();
148
149 loop {
150 let mut changed = false;
151 for f in &model.functions {
152 if closure.contains(&f.name) {
153 continue;
154 }
155 if f.node
156 .iter()
157 .any(|n| n.domain == MODULE_CALL_DOMAIN && closure.contains(&n.op_type))
158 {
159 closure.insert(f.name.clone());
160 changed = true;
161 }
162 }
163 if !changed {
164 break;
165 }
166 }
167 closure
168}
169
170fn pure_onnx_closure(model: &ModelProto) -> HashSet<String> {
175 let mut closure: HashSet<String> = HashSet::new();
176 loop {
177 let mut changed = false;
178 for f in &model.functions {
179 if closure.contains(&f.name) {
180 continue;
181 }
182 let all_ok = !f.node.is_empty()
186 && f.node.iter().all(|n| {
187 if n.domain == MODULE_CALL_DOMAIN {
188 closure.contains(&n.op_type)
189 } else {
190 n.domain == ONNX_DOMAIN
191 }
192 });
193 if all_ok {
194 closure.insert(f.name.clone());
195 changed = true;
196 }
197 }
198 if !changed {
199 break;
200 }
201 }
202 closure
203}
204
205fn reverse_topo_order(model: &ModelProto, inlinable: &HashSet<String>) -> Vec<String> {
210 let inlinable_idx: HashMap<String, usize> = model
211 .functions
212 .iter()
213 .enumerate()
214 .filter(|(_, f)| inlinable.contains(&f.name))
215 .map(|(i, f)| (f.name.clone(), i))
216 .collect();
217
218 let mut visited: HashSet<String> = HashSet::new();
219 let mut order: Vec<String> = Vec::new();
220
221 fn visit(
222 name: &str,
223 model: &ModelProto,
224 inlinable_idx: &HashMap<String, usize>,
225 visited: &mut HashSet<String>,
226 order: &mut Vec<String>,
227 ) {
228 if !visited.insert(name.to_string()) {
229 return;
230 }
231 let Some(&idx) = inlinable_idx.get(name) else {
232 return;
233 };
234 let f = &model.functions[idx];
235 for node in &f.node {
236 if node.domain == MODULE_CALL_DOMAIN && inlinable_idx.contains_key(&node.op_type) {
237 visit(&node.op_type, model, inlinable_idx, visited, order);
238 }
239 }
240 order.push(name.to_string());
241 }
242
243 let names: Vec<String> = inlinable_idx.keys().cloned().collect();
244 for name in &names {
245 visit(name, model, &inlinable_idx, &mut visited, &mut order);
246 }
247 order
248}
249
250fn inline_one_call(
256 body: &FunctionProto,
257 call: &NodeProto,
258 next_unique: &mut u64,
259) -> (Vec<NodeProto>, Vec<bb_ir::proto::onnx::ValueInfoProto>) {
260 let unique = *next_unique;
261 *next_unique = next_unique.saturating_add(1);
262
263 let mut rename: HashMap<String, String> = HashMap::new();
264 for (i, formal) in body.input.iter().enumerate() {
265 if let Some(actual) = call.input.get(i) {
266 rename.insert(formal.clone(), actual.clone());
267 }
268 }
269 for (i, body_out) in body.output.iter().enumerate() {
270 if let Some(call_out) = call.output.get(i) {
271 rename.insert(body_out.clone(), call_out.clone());
272 }
273 }
274
275 let mut rename_value = |name: &str| -> String {
276 if name.is_empty() {
277 return String::new();
278 }
279 if let Some(renamed) = rename.get(name) {
280 return renamed.clone();
281 }
282 let fresh = format!("{name}#inl{unique}");
283 rename.insert(name.to_string(), fresh.clone());
284 fresh
285 };
286
287 let mut out: Vec<NodeProto> = Vec::with_capacity(body.node.len());
288 for node in &body.node {
289 let mut cloned = node.clone();
290 for input in cloned.input.iter_mut() {
291 *input = rename_value(input);
292 }
293 for output in cloned.output.iter_mut() {
294 *output = rename_value(output);
295 }
296 out.push(cloned);
297 }
298
299 let value_info: Vec<bb_ir::proto::onnx::ValueInfoProto> = body
302 .value_info
303 .iter()
304 .filter_map(|vi| {
305 let new_name = rename.get(&vi.name).cloned()?;
306 let mut renamed = vi.clone();
307 renamed.name = new_name;
308 Some(renamed)
309 })
310 .collect();
311
312 (out, value_info)
313}
314