Skip to main content

bb_ops/backends/cpu/
graph_walker.rs

1//! `CpuBackend::execute_graph` — walk a `GraphProto` body (the kind
2//! the compiler's `collapse_backend_subgraphs` pass emits inside
3//! every `BackendSubgraph_*` `FunctionProto`) and run each
4//! NodeProto through the existing kernel dispatch.
5//!
6//! No fancy scheduling — ONNX guarantees `GraphProto.node` is
7//! already topologically ordered, so a linear walk suffices.
8
9use std::collections::HashMap;
10
11use bb_ir::proto::onnx::GraphProto;
12use bb_runtime::atomic::DispatchResult;
13use bb_runtime::bus::OpError;
14use bb_runtime::slot_value::SlotValue;
15
16use crate::backends::cpu::{ops, CpuBackend, CpuTensor};
17
18/// Failures `execute_graph` may surface alongside the kernel-level
19/// `OpError`s already routed through the existing dispatch.
20#[derive(Debug)]
21pub enum BackendError {
22    /// A node input name isn't in the value env. The graph either
23    /// uses a value the caller didn't bind OR an upstream node
24    /// failed to populate one of its declared outputs.
25    MissingInput {
26        /// Name of the missing value.
27        name: String,
28        /// op_type of the consuming node, for diagnostics.
29        op_type: String,
30    },
31
32    /// A node output value isn't a `CpuTensor`. The CpuBackend's
33    /// graph walker only handles f32 tensors; any other `SlotValue`
34    /// kind from a custom kernel rejects.
35    OutputNotTensor {
36        /// op_type that produced the offending value.
37        op_type: String,
38    },
39
40    /// The kernel itself returned an `OpError`. Wraps the error so
41    /// callers see which op surfaced it.
42    KernelFailed {
43        /// op_type that failed.
44        op_type: String,
45        /// Underlying kernel error.
46        source: OpError,
47    },
48
49    /// Bridged failure from the framework's default `Backend`
50    /// walker (malformed graph fed to `Backend::execute`).
51    DefaultWalker(bb_runtime::contracts::backend_default_walk::BackendWalkError),
52}
53
54impl From<bb_runtime::contracts::backend_default_walk::BackendWalkError> for BackendError {
55    fn from(value: bb_runtime::contracts::backend_default_walk::BackendWalkError) -> Self {
56        Self::DefaultWalker(value)
57    }
58}
59
60impl std::fmt::Display for BackendError {
61    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
62        match self {
63            Self::MissingInput { name, op_type } => {
64                write!(f, "execute_graph: input `{name}` missing for `{op_type}`",)
65            }
66            Self::OutputNotTensor { op_type } => write!(
67                f,
68                "execute_graph: `{op_type}` produced a non-CpuTensor output",
69            ),
70            Self::KernelFailed { op_type, source } => {
71                write!(f, "execute_graph: `{op_type}` kernel failed: {source}",)
72            }
73            Self::DefaultWalker(e) => write!(f, "{e}"),
74        }
75    }
76}
77
78impl std::error::Error for BackendError {}
79
80/// Run every `NodeProto` in `graph.node` in order, threading a
81/// `HashMap<String, CpuTensor>` value env. Returns the subset of
82/// `env` named in `graph.output`.
83///
84/// Pure over the `(graph, inputs)` pair — no engine context
85/// required. Each node's `node.attribute` is the kernel's
86/// attribute source; the dispatch path is identical to the
87/// `BackendRuntime::dispatch_atomic` per-op path except that
88/// attributes come from the NodeProto directly instead of through
89/// `RuntimeResourceRef::current_node_attributes`.
90pub fn execute_graph(
91    backend: &CpuBackend,
92    graph: &GraphProto,
93    inputs: HashMap<String, CpuTensor>,
94) -> Result<HashMap<String, CpuTensor>, BackendError> {
95    let mut env: HashMap<String, CpuTensor> = inputs;
96
97    for node in &graph.node {
98        let mut input_refs: Vec<(&str, &dyn SlotValue)> = Vec::with_capacity(node.input.len());
99        for name in &node.input {
100            if name.is_empty() {
101                continue;
102            }
103            let tensor = env.get(name).ok_or_else(|| BackendError::MissingInput {
104                name: name.clone(),
105                op_type: node.op_type.clone(),
106            })?;
107            input_refs.push((name.as_str(), tensor as &dyn SlotValue));
108        }
109
110        let result =
111            ops::dispatch(backend, &node.op_type, &input_refs, &node.attribute).map_err(|e| {
112                BackendError::KernelFailed {
113                    op_type: node.op_type.clone(),
114                    source: e,
115                }
116            })?;
117
118        let outputs = match result {
119            DispatchResult::Immediate(outs) => outs,
120            DispatchResult::Async(_) => {
121                return Err(BackendError::KernelFailed {
122                    op_type: node.op_type.clone(),
123                    source: OpError {
124                        detail: format!(
125                            "{op}: async dispatch unsupported inside execute_graph",
126                            op = node.op_type,
127                        ),
128                        ..Default::default()
129                    },
130                });
131            }
132        };
133
134        // Map each (kernel-named) output positionally to the
135        // NodeProto's `node.output[i]` name. Kernels label outputs
136        // `"C"` / `"Y"` / `"out_0"` etc., which don't have to match
137        // the consumer-side value name the graph references.
138        for (i, (_kernel_name, boxed)) in outputs.into_iter().enumerate() {
139            let Some(graph_name) = node.output.get(i) else {
140                // Kernel produced more outputs than the graph
141                // declares — drop the extra silently (the consumer
142                // doesn't reference it).
143                continue;
144            };
145            if graph_name.is_empty() {
146                continue;
147            }
148            // Consume the boxed kernel output into the env without
149            // cloning. `into_any_boxed` repackages `Box<dyn SlotValue>`
150            // as `Box<dyn Any>` so `Box::downcast` lands the concrete
151            // tensor by move.
152            let any = boxed.into_any_boxed();
153            let tensor: Box<CpuTensor> =
154                any.downcast::<CpuTensor>()
155                    .map_err(|_| BackendError::OutputNotTensor {
156                        op_type: node.op_type.clone(),
157                    })?;
158            env.insert(graph_name.clone(), *tensor);
159        }
160    }
161
162    // Return the subset named in graph.output. Missing names drop
163    // silently — callers explicitly ask for them by walking
164    // `graph.output` themselves if they need full coverage.
165    let mut out: HashMap<String, CpuTensor> = HashMap::new();
166    for vi in &graph.output {
167        if let Some(t) = env.remove(&vi.name) {
168            out.insert(vi.name.clone(), t);
169        }
170    }
171    Ok(out)
172}
173