Skip to main content

bb_compiler/
inline_for_partition.rs

1//! Selective function inlining for the partition stage.
2//!
3//! Modules don't drive network boundaries any more — wire ops do.
4//! Three classes of functions get inlined at every CALL site before
5//! [`crate::partition_by_wire_ops`] runs:
6//!
7//! 1. **Wire-touching functions** — any function whose transitive
8//!    closure contains an `ai.bytesandbrains.wire` op. Wire ops must
9//!    live at the top level so the partitioner's reachability walk
10//!    cuts the graph cleanly; a wire hidden behind a CALL fragments
11//!    the partition boundary.
12//!
13//! 2. **Pure-ONNX functions** — any function whose transitive closure
14//!    is entirely `ai.onnx.*`. Inlining surfaces each NodeProto at
15//!    the top level so the engine can route per-op against the
16//!    bound `Backend`'s Contract methods (`add`, `matmul`, …)
17//!    without an intervening CALL indirection.
18//!
19//! 3. **Single-call functions** — any function CALLed from exactly
20//!    one site across the whole call graph. Keeping it as a
21//!    FunctionProto saves no memory (the body would appear once
22//!    either way) but adds an indirection at dispatch time, so we
23//!    inline it eagerly.
24//!
25//! Functions called from multiple sites (and not in classes 1 or 2)
26//! survive as `FunctionProto` — the body appears once, callers
27//! reference it via CALL nodes.
28//!
29//! The root function (`model.functions[0]`) is always preserved
30//! regardless of its body's classification — it's the entry point
31//! the compiler partitions on.
32
33use std::collections::{HashMap, HashSet};
34
35use crate::error::CompileError;
36use bb_ir::proto::onnx::{FunctionProto, ModelProto, NodeProto};
37
/// Node domain that marks a CALL node. In this file a node in this
/// domain is treated as a call whose `op_type` is the callee
/// function's name.
38const MODULE_CALL_DOMAIN: &str = "ai.bytesandbrains.module";
/// Node domain of wire ops. A function with any node in this domain
/// is the seed of the wire-touching closure.
39const WIRE_DOMAIN: &str = "ai.bytesandbrains.wire";
/// Node domain accepted as a plain ONNX op by the pure-ONNX
/// classification. Compared for exact string equality.
40const ONNX_DOMAIN: &str = "ai.onnx";
41
42/// Inline every wire-touching or pure-ONNX function at every CALL
43/// site. Iterates to a fixed point because inlining a wire-touching
44/// callee may itself reveal a new wire-touching caller. Returns the
45/// total number of CALL replacements performed.
46pub fn inline_for_partition(model: &mut ModelProto) -> Result<usize, CompileError> {
47    let root_name = model.functions.first().map(|f| f.name.clone());
48    let mut total_inlines: usize = 0;
49    let mut next_unique: u64 = 0;
50
51    loop {
52        let inlinable = classify_inlinable(model, root_name.as_deref());
53        if inlinable.is_empty() {
54            break;
55        }
56        let order = reverse_topo_order(model, &inlinable);
57
58        for name in order {
59            // Snapshot the body — we splice copies of it into each
60            // CALL site.
61            let body = match model.functions.iter().find(|f| f.name == name) {
62                Some(f) => f.clone(),
63                None => continue,
64            };
65
66            for caller in model.functions.iter_mut() {
67                if caller.name == name {
68                    continue;
69                }
70                let mut rewritten: Vec<NodeProto> = Vec::with_capacity(caller.node.len());
71                let mut inlined_value_info: Vec<bb_ir::proto::onnx::ValueInfoProto> = Vec::new();
72                for node in caller.node.iter() {
73                    if node.domain == MODULE_CALL_DOMAIN && node.op_type == name {
74                        let (nodes, value_info) = inline_one_call(&body, node, &mut next_unique);
75                        rewritten.extend(nodes);
76                        inlined_value_info.extend(value_info);
77                        total_inlines += 1;
78                    } else {
79                        rewritten.push(node.clone());
80                    }
81                }
82                caller.node = rewritten;
83                // Inlined sub-function value_info entries ride into
84                // the caller so strict-types-by-default sees a
85                // declared denotation for every renamed intermediate.
86                for vi in inlined_value_info {
87                    if !caller.value_info.iter().any(|v| v.name == vi.name) {
88                        caller.value_info.push(vi);
89                    }
90                }
91            }
92        }
93
94        // Drop the inlined functions from the table — their bodies
95        // now live at every former call site.
96        model.functions.retain(|f| !inlinable.contains(&f.name));
97    }
98
99    Ok(total_inlines)
100}
101
102/// Classify every non-root function as inlinable or kept. A
103/// function is inlinable iff it's wire-touching, pure-ONNX, or
104/// called from exactly one site.
105fn classify_inlinable(model: &ModelProto, root_name: Option<&str>) -> HashSet<String> {
106    let wire_touching = wire_closure(model);
107    let pure_onnx = pure_onnx_closure(model);
108    let call_counts = count_call_sites(model);
109
110    let mut result = HashSet::new();
111    for f in &model.functions {
112        if root_name == Some(f.name.as_str()) {
113            continue;
114        }
115        let single_call = call_counts.get(&f.name).copied() == Some(1);
116        if wire_touching.contains(&f.name) || pure_onnx.contains(&f.name) || single_call {
117            result.insert(f.name.clone());
118        }
119    }
120    result
121}
122
123/// Count CALL sites referencing each function across the model.
124/// Keyed by callee function name.
125fn count_call_sites(model: &ModelProto) -> HashMap<String, usize> {
126    let mut counts: HashMap<String, usize> = HashMap::new();
127    for f in &model.functions {
128        for node in &f.node {
129            if node.domain == MODULE_CALL_DOMAIN {
130                *counts.entry(node.op_type.clone()).or_insert(0) += 1;
131            }
132        }
133    }
134    counts
135}
136
137/// Functions whose transitive closure contains any wire op. Computed
138/// by starting at functions with a direct wire op in their body and
139/// propagating up the call graph (any caller of a wire-touching
140/// function is itself wire-touching).
141fn wire_closure(model: &ModelProto) -> HashSet<String> {
142    let mut closure: HashSet<String> = model
143        .functions
144        .iter()
145        .filter(|f| f.node.iter().any(|n| n.domain == WIRE_DOMAIN))
146        .map(|f| f.name.clone())
147        .collect();
148
149    loop {
150        let mut changed = false;
151        for f in &model.functions {
152            if closure.contains(&f.name) {
153                continue;
154            }
155            if f.node
156                .iter()
157                .any(|n| n.domain == MODULE_CALL_DOMAIN && closure.contains(&n.op_type))
158            {
159                closure.insert(f.name.clone());
160                changed = true;
161            }
162        }
163        if !changed {
164            break;
165        }
166    }
167    closure
168}
169
170/// Functions whose transitive closure is entirely `ai.onnx.*`.
171/// Computed by iterating: a function is pure-ONNX if every node in
172/// its body is either (a) a CALL to another pure-ONNX function or
173/// (b) a direct `ai.onnx` op.
174fn pure_onnx_closure(model: &ModelProto) -> HashSet<String> {
175    let mut closure: HashSet<String> = HashSet::new();
176    loop {
177        let mut changed = false;
178        for f in &model.functions {
179            if closure.contains(&f.name) {
180                continue;
181            }
182            // Empty body counts as pure-ONNX vacuously, but only
183            // matters for synthetic cases — real ONNX functions have
184            // at least one op.
185            let all_ok = !f.node.is_empty()
186                && f.node.iter().all(|n| {
187                    if n.domain == MODULE_CALL_DOMAIN {
188                        closure.contains(&n.op_type)
189                    } else {
190                        n.domain == ONNX_DOMAIN
191                    }
192                });
193            if all_ok {
194                closure.insert(f.name.clone());
195                changed = true;
196            }
197        }
198        if !changed {
199            break;
200        }
201    }
202    closure
203}
204
205/// Reverse-topological order over the inlinable subset. Leaves of
206/// the call graph (functions whose bodies contain no CALLs to other
207/// inlinable functions) come first so each inline operation sees a
208/// body that no longer references other inlinables.
209fn reverse_topo_order(model: &ModelProto, inlinable: &HashSet<String>) -> Vec<String> {
210    let inlinable_idx: HashMap<String, usize> = model
211        .functions
212        .iter()
213        .enumerate()
214        .filter(|(_, f)| inlinable.contains(&f.name))
215        .map(|(i, f)| (f.name.clone(), i))
216        .collect();
217
218    let mut visited: HashSet<String> = HashSet::new();
219    let mut order: Vec<String> = Vec::new();
220
221    fn visit(
222        name: &str,
223        model: &ModelProto,
224        inlinable_idx: &HashMap<String, usize>,
225        visited: &mut HashSet<String>,
226        order: &mut Vec<String>,
227    ) {
228        if !visited.insert(name.to_string()) {
229            return;
230        }
231        let Some(&idx) = inlinable_idx.get(name) else {
232            return;
233        };
234        let f = &model.functions[idx];
235        for node in &f.node {
236            if node.domain == MODULE_CALL_DOMAIN && inlinable_idx.contains_key(&node.op_type) {
237                visit(&node.op_type, model, inlinable_idx, visited, order);
238            }
239        }
240        order.push(name.to_string());
241    }
242
243    let names: Vec<String> = inlinable_idx.keys().cloned().collect();
244    for name in &names {
245        visit(name, model, &inlinable_idx, &mut visited, &mut order);
246    }
247    order
248}
249
250/// Splice one inlined copy of `body` in place of `call`. Intermediate
251/// value names get a unique suffix to avoid collisions across
252/// multiple inlines; formal-input and body-output names are rewritten
253/// to the CALL's actual arg/output names so downstream caller-side
254/// consumers still resolve.
255fn inline_one_call(
256    body: &FunctionProto,
257    call: &NodeProto,
258    next_unique: &mut u64,
259) -> (Vec<NodeProto>, Vec<bb_ir::proto::onnx::ValueInfoProto>) {
260    let unique = *next_unique;
261    *next_unique = next_unique.saturating_add(1);
262
263    let mut rename: HashMap<String, String> = HashMap::new();
264    for (i, formal) in body.input.iter().enumerate() {
265        if let Some(actual) = call.input.get(i) {
266            rename.insert(formal.clone(), actual.clone());
267        }
268    }
269    for (i, body_out) in body.output.iter().enumerate() {
270        if let Some(call_out) = call.output.get(i) {
271            rename.insert(body_out.clone(), call_out.clone());
272        }
273    }
274
275    let mut rename_value = |name: &str| -> String {
276        if name.is_empty() {
277            return String::new();
278        }
279        if let Some(renamed) = rename.get(name) {
280            return renamed.clone();
281        }
282        let fresh = format!("{name}#inl{unique}");
283        rename.insert(name.to_string(), fresh.clone());
284        fresh
285    };
286
287    let mut out: Vec<NodeProto> = Vec::with_capacity(body.node.len());
288    for node in &body.node {
289        let mut cloned = node.clone();
290        for input in cloned.input.iter_mut() {
291            *input = rename_value(input);
292        }
293        for output in cloned.output.iter_mut() {
294            *output = rename_value(output);
295        }
296        out.push(cloned);
297    }
298
299    // Copy value_info entries from the body into the caller under
300    // the renamed names so denotations ride through inlining.
301    let value_info: Vec<bb_ir::proto::onnx::ValueInfoProto> = body
302        .value_info
303        .iter()
304        .filter_map(|vi| {
305            let new_name = rename.get(&vi.name).cloned()?;
306            let mut renamed = vi.clone();
307            renamed.name = new_name;
308            Some(renamed)
309        })
310        .collect();
311
312    (out, value_info)
313}
314