bb_ops/backends/cpu/graph_walker.rs
1//! `CpuBackend::execute_graph` — walk a `GraphProto` body (the kind
2//! the compiler's `collapse_backend_subgraphs` pass emits inside
3//! every `BackendSubgraph_*` `FunctionProto`) and run each
4//! NodeProto through the existing kernel dispatch.
5//!
6//! No fancy scheduling — ONNX guarantees `GraphProto.node` is
7//! already topologically ordered, so a linear walk suffices.
8
9use std::collections::HashMap;
10
11use bb_ir::proto::onnx::GraphProto;
12use bb_runtime::atomic::DispatchResult;
13use bb_runtime::bus::OpError;
14use bb_runtime::slot_value::SlotValue;
15
16use crate::backends::cpu::{ops, CpuBackend, CpuTensor};
17
/// Errors returned by `execute_graph`.
///
/// Covers the walker's own failure modes (unbound inputs, non-tensor
/// outputs) and wraps kernel-level `OpError`s coming out of the
/// existing dispatch so the caller can see which op raised them.
#[derive(Debug)]
pub enum BackendError {
    /// A node input name isn't in the value env. The graph either
    /// uses a value the caller didn't bind OR an upstream node
    /// failed to populate one of its declared outputs.
    MissingInput {
        /// Name of the missing value.
        name: String,
        /// op_type of the consuming node, for diagnostics.
        op_type: String,
    },

    /// A node output value isn't a `CpuTensor`. The CpuBackend's
    /// graph walker only handles f32 tensors; any other `SlotValue`
    /// kind coming from a custom kernel is rejected.
    OutputNotTensor {
        /// op_type that produced the offending value.
        op_type: String,
    },

    /// The kernel itself returned an `OpError`. Wraps the error so
    /// callers see which op surfaced it. `execute_graph` also uses
    /// this variant when a kernel answers with
    /// `DispatchResult::Async`, which the walker does not support.
    KernelFailed {
        /// op_type that failed.
        op_type: String,
        /// Underlying kernel error.
        source: OpError,
    },

    /// Bridged failure from the framework's default `Backend`
    /// walker (malformed graph fed to `Backend::execute`). Built
    /// through the `From<BackendWalkError>` impl on this type.
    DefaultWalker(bb_runtime::contracts::backend_default_walk::BackendWalkError),
}
53
54impl From<bb_runtime::contracts::backend_default_walk::BackendWalkError> for BackendError {
55 fn from(value: bb_runtime::contracts::backend_default_walk::BackendWalkError) -> Self {
56 Self::DefaultWalker(value)
57 }
58}
59
60impl std::fmt::Display for BackendError {
61 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
62 match self {
63 Self::MissingInput { name, op_type } => {
64 write!(f, "execute_graph: input `{name}` missing for `{op_type}`",)
65 }
66 Self::OutputNotTensor { op_type } => write!(
67 f,
68 "execute_graph: `{op_type}` produced a non-CpuTensor output",
69 ),
70 Self::KernelFailed { op_type, source } => {
71 write!(f, "execute_graph: `{op_type}` kernel failed: {source}",)
72 }
73 Self::DefaultWalker(e) => write!(f, "{e}"),
74 }
75 }
76}
77
// Marks `BackendError` as a standard error type. No trait methods are
// overridden, so `Error::source()` keeps its default (`None`); the
// wrapped kernel / walker errors are reachable through the variant
// payloads and the `Display` text rather than the source chain.
impl std::error::Error for BackendError {}
79
80/// Run every `NodeProto` in `graph.node` in order, threading a
81/// `HashMap<String, CpuTensor>` value env. Returns the subset of
82/// `env` named in `graph.output`.
83///
84/// Pure over the `(graph, inputs)` pair — no engine context
85/// required. Each node's `node.attribute` is the kernel's
86/// attribute source; the dispatch path is identical to the
87/// `BackendRuntime::dispatch_atomic` per-op path except that
88/// attributes come from the NodeProto directly instead of through
89/// `RuntimeResourceRef::current_node_attributes`.
90pub fn execute_graph(
91 backend: &CpuBackend,
92 graph: &GraphProto,
93 inputs: HashMap<String, CpuTensor>,
94) -> Result<HashMap<String, CpuTensor>, BackendError> {
95 let mut env: HashMap<String, CpuTensor> = inputs;
96
97 for node in &graph.node {
98 let mut input_refs: Vec<(&str, &dyn SlotValue)> = Vec::with_capacity(node.input.len());
99 for name in &node.input {
100 if name.is_empty() {
101 continue;
102 }
103 let tensor = env.get(name).ok_or_else(|| BackendError::MissingInput {
104 name: name.clone(),
105 op_type: node.op_type.clone(),
106 })?;
107 input_refs.push((name.as_str(), tensor as &dyn SlotValue));
108 }
109
110 let result =
111 ops::dispatch(backend, &node.op_type, &input_refs, &node.attribute).map_err(|e| {
112 BackendError::KernelFailed {
113 op_type: node.op_type.clone(),
114 source: e,
115 }
116 })?;
117
118 let outputs = match result {
119 DispatchResult::Immediate(outs) => outs,
120 DispatchResult::Async(_) => {
121 return Err(BackendError::KernelFailed {
122 op_type: node.op_type.clone(),
123 source: OpError {
124 detail: format!(
125 "{op}: async dispatch unsupported inside execute_graph",
126 op = node.op_type,
127 ),
128 ..Default::default()
129 },
130 });
131 }
132 };
133
134 // Map each (kernel-named) output positionally to the
135 // NodeProto's `node.output[i]` name. Kernels label outputs
136 // `"C"` / `"Y"` / `"out_0"` etc., which don't have to match
137 // the consumer-side value name the graph references.
138 for (i, (_kernel_name, boxed)) in outputs.into_iter().enumerate() {
139 let Some(graph_name) = node.output.get(i) else {
140 // Kernel produced more outputs than the graph
141 // declares — drop the extra silently (the consumer
142 // doesn't reference it).
143 continue;
144 };
145 if graph_name.is_empty() {
146 continue;
147 }
148 // Consume the boxed kernel output into the env without
149 // cloning. `into_any_boxed` repackages `Box<dyn SlotValue>`
150 // as `Box<dyn Any>` so `Box::downcast` lands the concrete
151 // tensor by move.
152 let any = boxed.into_any_boxed();
153 let tensor: Box<CpuTensor> =
154 any.downcast::<CpuTensor>()
155 .map_err(|_| BackendError::OutputNotTensor {
156 op_type: node.op_type.clone(),
157 })?;
158 env.insert(graph_name.clone(), *tensor);
159 }
160 }
161
162 // Return the subset named in graph.output. Missing names drop
163 // silently — callers explicitly ask for them by walking
164 // `graph.output` themselves if they need full coverage.
165 let mut out: HashMap<String, CpuTensor> = HashMap::new();
166 for vi in &graph.output {
167 if let Some(t) = env.remove(&vi.name) {
168 out.insert(vi.name.clone(), t);
169 }
170 }
171 Ok(out)
172}
173